1 // -*- indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
35 // Patrick Lam (prof.lam@gmail.com)
38 // Eli Bendersky (eliben@gmail.com)
40 //------------------------------------------------------------------------------
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
// File-scope configuration objects: Clang language options and the llvm::cl option
// category whose title string is "Add MC2 Annotations".
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
// encode: derive the MC-variable name that corresponds to a source variable name.
// The visible statement streams "_m" followed by varName into nn. The declaration of nn
// and the return statement are not visible in this excerpt.
71 static std::string encode(std::string varName) {
73 nn << "_m" << varName;
// encodeFn: build a std::string name from a counter value (body not visible in this
// excerpt). AssignHandler::run uses the result as the variable that receives the value
// of a generated MC2_function_id(...) call.
78 static std::string encodeFn(int num) {
// encodePtr: build a std::string name from a counter value. Neither the body nor any
// caller is visible in this excerpt.
// NOTE(review): presumably names a temporary pointer variable -- confirm against the body.
85 static std::string encodePtr(int num) {
// encodeRMW: build a std::string name from a counter value (body not visible in this
// excerpt). RMWHandler::run declares the result as an MCID variable in the emitted text.
92 static std::string encodeRMW(int num) {
// branchCount: counter passed (post-incremented) to encodeBranch by BranchAnnotationHandler.
98 static int branchCount;
// encodeBranch: build a std::string name from num via a stringstream. The formatting and
// return statements are not visible in this excerpt. BranchAnnotationHandler::run declares
// the result as an MCID variable.
99 static std::string encodeBranch(int num) {
100 std::stringstream nn;
// condCount: counter passed (post-incremented) to encodeCond by
// BranchConditionRefactoringHandler.
105 static int condCount;
// encodeCond: format the name of a hoisted condition variable as "_cond" followed by num.
// The return statement is not visible in this excerpt.
106 static std::string encodeCond(int num) {
107 std::stringstream nn;
108 nn << "_cond" << num;
// encodeRV: build a std::string name from num via a stringstream. The formatting, the
// return statement and all callers are outside the visible lines of this excerpt.
113 static std::string encodeRV(int num) {
114 std::stringstream nn;
// funcCount: pre-incremented (++funcCount) wherever the visible code emits an
// MC2_function_id(...) call, where it supplies the first argument.
119 static int funcCount;
// ProvisionalName: records a span inside an Update's text that names a variable whose
// final MC-variable name may not be known yet. updateProvisionalName() later replaces
// `length` characters starting at `index` with the real name.
// Only the pname member declaration is visible here; index, length and enabled are
// initialised by the constructors but their declarations are not shown.
121 struct ProvisionalName {
// The variable reference whose declaration is compared against a newly known decl.
123 const DeclRefExpr * pname;
// Length defaults to the length of encode(<referenced variable's name>).
126 ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
// Caller supplies the placeholder length explicitly (e.g. 0, or the length of "MCID_NODEP").
127 ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
// (Part of the Update type; its opening line and the loc/update member declarations are
// not visible in this excerpt.)
// Provisional names that refer to positions inside this update's text.
133 std::vector<ProvisionalName *> * pnames;
// Stores the insertion location, the text to insert, and the provisional-name list.
135 Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) :
136 loc(loc), update(update), pnames(pnames) {}
// Deletes every ProvisionalName in the list (the enclosing function header is not visible).
139 for (auto pname : *pnames) delete pname;
// updateProvisionalName: patch pending update texts once the MC variable for a
// declaration becomes known.
//   DeferredUpdates - pending updates whose text may contain provisional names
//   now_known       - the declaration whose MC variable has just been determined
//   mcVar           - the text to substitute for each matching provisional name
// For every enabled provisional name whose DeclRefExpr refers to now_known, the span
// [index, index + length) of the update text is replaced with mcVar, and the indices of
// later provisional names in the same update are shifted by the change in length.
// NOTE(review): `int` loop counters are compared against size() (unsigned), and the shift
// mixes int with std::string::size_type -- confirm this is intended.
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145 for (Update * u : DeferredUpdates) {
146 for (int i = 0; i < u->pnames->size(); i++) {
147 ProvisionalName * v = (*(u->pnames))[i];
148 if (!v->enabled) continue;
149 if (now_known == v->pname->getDecl()) {
// oldName is computed here; its use is not visible in this excerpt.
151 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
153 u->update.replace(v->index, v->length, mcVar);
// Re-base provisional names that sit after the replaced span.
154 for (int j = i+1; j < u->pnames->size(); j++) {
155 ProvisionalName * vv = (*(u->pnames))[j];
156 if (vv->index > v->index)
157 vv->index -= v->length - mcVar.length();
// retrieveSingleDecl: return the VarDecl declared by a DeclStmt that declares exactly one
// variable. Asserts that shape, then re-checks it before casting. What is returned when
// the check fails (relevant when asserts are compiled out) is not visible in this excerpt.
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165 // XXX iterate through all decls defined in s, not just the first one
166 assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167 if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168 return cast<VarDecl>(s->getSingleDecl());
// FindCallArgVisitor: AST visitor that looks, within a traversed expression, for a
// DeclRefExpr (kept in DE) and for address-of / dereference unary operators.
// Several statements of each method are not visible in this excerpt.
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
174 FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
// Called for every statement visited. Tests for a & or * unary operator (the action taken
// is not visible here) and records the first DeclRefExpr encountered while DE is unset.
176 bool VisitStmt(Stmt * s) {
178 if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180 uo->getOpcode() == UnaryOperatorKind::UO_Deref)
185 if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
// Resets both results (the enclosing method header is not visible in this excerpt).
191 UnaryOp = NULL; DE = NULL;
// Starts from the recorded unary operator and steps through sub-expressions of unary
// operators, casts and member expressions; the loop structure and the returned value are
// not visible in this excerpt.
194 const UnaryOperator * RetrieveUnaryOp() {
197 const Stmt * s = UnaryOp;
203 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204 s = op->getSubExpr();
205 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206 s = op->getSubExpr();
207 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
// Accessor for the recorded DeclRefExpr (body not visible in this excerpt).
218 const DeclRefExpr * RetrieveDeclRefExpr() {
223 const UnaryOperator * UnaryOp;
224 const DeclRefExpr * DE;
// FindLocalsVisitor: AST visitor that collects the declaration referenced by every
// DeclRefExpr it visits, in visitation order (duplicates are not filtered by the visible
// code).
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
229 FindLocalsVisitor() : Vars() {}
// Appends the referenced declaration; the return statement is not visible in this excerpt.
231 bool VisitDeclRefExpr(DeclRefExpr * de) {
232 Vars.push_back(de->getDecl());
// Returns the collected declarations by value (body not visible in this excerpt).
240 const TinyPtrVector<const NamedDecl *> RetrieveVars() {
245 TinyPtrVector<const NamedDecl *> Vars;
// MallocHandler: match callback that records each call expression bound as "callExpr" in
// the shared MallocExprs set. The matcher that selects those calls is not visible here.
249 class MallocHandler : public MatchFinder::MatchCallback {
251 MallocHandler(std::set<const Expr *> & MallocExprs) :
252 MallocExprs(MallocExprs) {}
254 virtual void run(const MatchFinder::MatchResult &Result) {
255 const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
257 MallocExprs.insert(ce);
// Reference to the caller-owned set; AssignHandler::run consults the same set type.
261 std::set<const Expr *> &MallocExprs;
// generateMC2Function: build the text of a temporary-pointer declaration followed by an
// MC2_function_id(...) call, and queue it as a deferred Update.
// Only some parameters are visible in this excerpt (Rewrite, lhs, vars1, DeferredUpdates).
// The body also uses loc, e, tmpname, tmpFn, SStr and lhsName, whose declarations are not
// shown.
// The name of `lhs` is emitted in encoded form and registered as a ProvisionalName so it
// can be replaced later by updateProvisionalName().
264 static void generateMC2Function(Rewriter & Rewrite,
269 const DeclRefExpr * lhs,
271 std::vector<ProvisionalName *> * vars1,
272 std::vector<Update *> & DeferredUpdates) {
273 // prettyprint the LHS (&newnode->value)
274 // e.g. int * _tmp0 = &newnode->value;
276 llvm::raw_string_ostream S(SStr);
277 e->printPretty(S, nullptr, Rewrite.getLangOpts());
278 const std::string &Str = S.str();
280 std::stringstream prel;
281 prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
283 // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284 prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
286 // XXX generate casts when they'd eliminate warnings
// tellp() gives the offset in prel at which the encoded lhs name is about to be written.
287 ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
290 prel << encode(lhsName) << "); ";
// Appended to the back of the deferred list, unlike LoadHandler/RMWHandler, which insert
// at the front.
292 Update * u = new Update(loc, prel.str(), vars1);
293 DeferredUpdates.push_back(u);
// LoadHandler: match callback that emits MC2_nextOpLoad / MC2_nextOpLoadOffset annotation
// text for a call expression bound as "callExpr". The matcher itself is not visible here.
// The annotation is queued as an Update rather than written immediately, so provisional
// variable names inside it can still be patched. Many lines of run() are not visible in
// this excerpt, including the conditions that select between its branches.
296 class LoadHandler : public MatchFinder::MatchCallback {
// All collection parameters are stored as references to caller-owned state.
298 LoadHandler(Rewriter &Rewrite,
299 std::set<const NamedDecl *> & DeclsRead,
300 std::set<const VarDecl *> & DeclsNeedingMC,
301 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302 std::map<const Expr *, std::string> &ExprToMCVar,
303 std::set<const Stmt *> & StmtsHandled,
304 std::map<const Expr *, SourceLocation> &Redirector,
305 std::vector<Update *> & DeferredUpdates) :
306 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307 ExprToMCVar(ExprToMCVar),
308 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
// run: reads the "callExpr", "decl" and "containingStmt" bindings of the match.
310 virtual void run(const MatchFinder::MatchResult &Result) {
311 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313 const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315 const Expr * lhs = NULL;
// When the containing statement is a binary operator, its left operand is the load's
// destination.
316 if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
319 const DeclRefExpr * rhs = NULL;
320 MemberExpr * ml = NULL;
321 bool isAddrOfR = false, isAddrMemberR = false;
// Mark the containing statement as handled; AssignHandler::run checks this set.
323 StmtsHandled.insert(s);
325 std::string n, n_decl;
// First visible path: the source operand comes from call argument 0, and every
// declaration referenced in `lhs` gets an MC-variable name recorded.
327 FindCallArgVisitor fcaVisitor;
329 fcaVisitor.TraverseStmt(ce->getArg(0));
330 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
335 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
337 FindLocalsVisitor flv;
339 flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340 for (auto & d : flv.RetrieveVars()) {
// cast<> asserts that each referenced declaration is a VarDecl.
341 const VarDecl * dd = cast<VarDecl>(d);
343 // XXX todo rhs for non-decl stmts
344 if (!isa<ParmVarDecl>(dd))
345 DeclsNeedingMC.insert(dd);
347 DeclToMCVar[dd] = encode(n);
// Second visible path: same operand analysis, but the destination is the matched
// declaration `d`.
350 FindCallArgVisitor fcaVisitor;
352 fcaVisitor.TraverseStmt(ce->getArg(0));
353 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
358 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
362 DeclsNeedingMC.insert(d);
364 DeclToMCVar[d] = encode(n);
// Third visible path: the whole call is traversed and the first DeclRefExpr found is
// treated as now known, so already-queued updates that mention it are patched.
368 fcaVisitor.TraverseStmt(ce);
369 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371 DeclToMCVar[dd->getDecl()] = encode(n);
// nol accumulates the annotation text; tellp() offsets are captured just before each
// name that may later be rewritten.
376 std::stringstream nol;
378 if (lhs && isa<DeclRefExpr>(lhs)) {
379 const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380 ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
// Member access through an address: emit the offset form, built from the base
// expression's type and the member's name.
387 nol << n_decl << encode(n) << "=";
388 nol << "MC2_nextOpLoadOffset(";
390 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
392 nol << encode(rhs->getNameInfo().getName().getAsString());
394 nol << ", MC2_OFFSET(";
395 nol << ml->getBase()->getType().getAsString();
397 nol << ml->getMemberDecl()->getName().str();
399 } else if (!isAddrOfR) {
401 nol << n_decl << encode(n) << "=";
402 nol << "MC2_nextOpLoad(";
403 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
405 nol << encode(rhs->getNameInfo().getName().getAsString());
408 nol << n_decl << encode(n) << "=";
409 nol << "MC2_nextOpLoad(";
414 nol << n_decl << encode(n) << "=";
415 nol << "MC2_nextOpLoad(";
423 SourceLocation ss = s->getLocStart();
425 // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426 // move over 1 so that we get the right location.
427 if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
// If the statement is an expression with a Redirector entry, a different location is
// presumably used -- the statement under this `if` is not visible in this excerpt.
428 const Expr * e = dyn_cast<Expr>(s);
429 if (e && Redirector.count(e) > 0)
// Queued at the FRONT of the deferred list.
431 Update * u = new Update(ss, nol.str(), vars);
432 DeferredUpdates.insert(DeferredUpdates.begin(), u);
// References to caller-owned state (the Rewrite member declaration is not visible here).
437 std::set<const NamedDecl *> & DeclsRead;
438 std::set<const VarDecl *> & DeclsNeedingMC;
439 std::map<const Expr *, std::string> &ExprToMCVar;
440 std::map<const NamedDecl *, std::string> &DeclToMCVar;
441 std::set<const Stmt *> &StmtsHandled;
442 std::vector<Update *> &DeferredUpdates;
443 std::map<const Expr *, SourceLocation> &Redirector;
// StoreHandler: match callback that emits MC2_nextOpStore / MC2_nextOpStoreOffset
// annotation text for a call expression bound as "callExpr". The matcher is not visible
// here. Call argument 0 is analysed as the store target (`lhs`) and argument 1 as the
// stored value (`rhs`). The annotation is queued as an Update at the call's start
// location. Several lines of run(), including some branch conditions, are not visible in
// this excerpt.
446 class StoreHandler : public MatchFinder::MatchCallback {
448 StoreHandler(Rewriter &Rewrite,
449 std::set<const NamedDecl *> & DeclsRead,
450 std::set<const VarDecl *> &DeclsNeedingMC,
451 std::vector<Update *> & DeferredUpdates) :
452 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
454 virtual void run(const MatchFinder::MatchResult &Result) {
455 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
// Analyse the store target.
459 fcaVisitor.TraverseStmt(ce->getArg(0));
460 const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461 const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
463 std::stringstream nol;
// Classify an address-of target as either &member-expression or a plain address-of.
468 if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469 isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470 isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
475 nol << "MC2_nextOpStore(";
// Member target: emit the offset form using the base expression's type and member name.
479 MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
481 nol << "MC2_nextOpStoreOffset(";
483 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
486 nol << encode(lhs->getNameInfo().getName().getAsString());
// Non-parameter targets are recorded as needing an MC variable.
487 if (!isa<ParmVarDecl>(lhs->getDecl()))
488 DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
490 nol << ", MC2_OFFSET(";
491 nol << ml->getBase()->getType().getAsString();
493 nol << ml->getMemberDecl()->getName().str();
496 nol << "MC2_nextOpStore(";
497 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
500 nol << encode(lhs->getNameInfo().getName().getAsString());
505 nol << "MC2_nextOpStore(";
// Analyse the stored value.
512 fcaVisitor.TraverseStmt(ce->getArg(1));
513 const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
516 bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517 bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
// A variable used by value: emit its encoded name provisionally and record it as read.
519 if (rhs && !isAddrOfR) {
520 assert (!isDerefR && "Must use atomic load for dereferences!");
521 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
524 nol << encode(rhs->getNameInfo().getName().getAsString());
525 DeclsRead.insert(rhs->getDecl());
// Queued at the BACK of the deferred list.
531 Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532 DeferredUpdates.push_back(u);
// Visitor instance reused across both argument traversals; any reset between them is not
// visible in this excerpt.
537 FindCallArgVisitor fcaVisitor;
538 std::set<const NamedDecl *> & DeclsRead;
539 std::set<const VarDecl *> & DeclsNeedingMC;
540 std::vector<Update *> &DeferredUpdates;
// RMWHandler: match callback that emits an "MCID <var> = MC2_nextRMW(...)" or
// "MC2_nextRMWOffset(...)" declaration for a call expression bound as "callExpr". The
// matcher is not visible here. Call argument 1 is analysed as the effective target and
// arguments 2 and 3 as the value operands. The text is queued as an Update at the front
// of the deferred list. Several lines of run() are not visible in this excerpt.
543 class RMWHandler : public MatchFinder::MatchCallback {
545 RMWHandler(Rewriter &rewrite,
546 std::set<const NamedDecl *> & DeclsRead,
547 std::set<const NamedDecl *> & DeclsInCond,
548 std::map<const NamedDecl *, std::string> &DeclToMCVar,
549 std::map<const Expr *, std::string> &ExprToMCVar,
550 std::set<const Stmt *> & StmtsHandled,
551 std::map<const Expr *, SourceLocation> &Redirector,
552 std::vector<Update *> & DeferredUpdates) :
553 rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
554 ExprToMCVar(ExprToMCVar),
555 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
557 virtual void run(const MatchFinder::MatchResult &Result) {
558 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
562 std::stringstream nol;
// Fresh MCID variable name for this RMW (the rmwCount declaration is not visible here).
564 std::string rmwMCVar;
565 rmwMCVar = encodeRMW(rmwCount++);
567 const VarDecl * rmw_lhs;
569 StmtsHandled.insert(s);
// NOTE(review): && binds tighter than ||, so the message string attaches only to the
// second operand. The string literal is always true, so the check itself is unaffected.
570 assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
// The variable receiving the RMW result is either the declared variable or the left
// operand of the assignment; it is mapped to the new MCID variable.
572 if ((ds = dyn_cast<DeclStmt>(s))) {
573 rmw_lhs = retrieveSingleDecl(ds);
575 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576 assert (isa<DeclRefExpr>(e));
577 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
579 DeclToMCVar[rmw_lhs] = rmwMCVar;
582 // retrieve effective LHS of the RMW
584 fcaVisitor.TraverseStmt(ce->getArg(1));
585 const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586 const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587 bool isAddrMemberL = false;
589 if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590 isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
593 nol << "MCID " << rmwMCVar;
// Member target: offset form, built from the base expression's type and the member name.
595 MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
597 nol << " = MC2_nextRMWOffset(";
599 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
602 nol << encode(elhs->getNameInfo().getName().getAsString());
604 nol << ", MC2_OFFSET(";
605 nol << ml->getBase()->getType().getAsString();
607 nol << ml->getMemberDecl()->getName().str();
610 nol << " = MC2_nextRMW(";
611 bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
617 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
620 std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
629 // handle both RHS ops
631 for (int arg = 2; arg < 4; arg++) {
633 fcaVisitor.TraverseStmt(ce->getArg(arg));
634 const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635 const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
637 bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638 bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
// Operand is a variable used by value.
640 if (a && !isAddrOfR) {
641 assert (!isDerefR && "Must use atomic load for dereferences!");
643 DeclsInCond.insert(a->getDecl());
645 if (outputted > 0) nol << ", ";
// Use the known MC variable if there is one; otherwise a "MCID_NODEP" placeholder is
// registered as a provisional name with the placeholder's length, so it can be replaced
// later.
648 bool alreadyMCVar = false;
649 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
651 nol << DeclToMCVar[a->getDecl()];
654 std::string an = "MCID_NODEP";
655 ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
660 DeclsRead.insert(a->getDecl());
663 if (outputted > 0) nol << ", ";
// Insert at the containing statement (or the call itself when there is none), unless the
// Redirector map supplies a different location for that expression.
671 SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672 const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673 if (e && Redirector.count(e) > 0)
674 place = Redirector[e];
675 Update * u = new Update(place, nol.str(), vars);
676 DeferredUpdates.insert(DeferredUpdates.begin(), u);
// Visitor instance reused across argument traversals; any reset between them is not
// visible in this excerpt.
681 FindCallArgVisitor fcaVisitor;
682 std::set<const NamedDecl *> &DeclsRead;
683 std::set<const NamedDecl *> &DeclsInCond;
684 std::map<const NamedDecl *, std::string> &DeclToMCVar;
685 std::map<const Expr *, std::string> &ExprToMCVar;
686 std::set<const Stmt *> &StmtsHandled;
687 std::vector<Update *> &DeferredUpdates;
688 std::map<const Expr *, SourceLocation> &Redirector;
// FindReturnsBreaksVisitor: AST visitor that collects every ReturnStmt and BreakStmt it
// visits into two separate vectors.
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
693 FindReturnsBreaksVisitor() : Returns(), Breaks() {}
// Called for each statement; the return statement is not visible in this excerpt.
695 bool VisitStmt(Stmt * s) {
696 if (isa<ReturnStmt>(s))
697 Returns.push_back(cast<ReturnStmt>(s));
699 if (isa<BreakStmt>(s))
700 Breaks.push_back(cast<BreakStmt>(s));
// Empties both result vectors (the enclosing method header is not visible in this excerpt).
705 Returns.clear(); Breaks.clear();
// Accessors returning the collected statements by value (bodies not visible here).
708 const std::vector<const ReturnStmt *> RetrieveReturns() {
712 const std::vector<const BreakStmt *> RetrieveBreaks() {
717 std::vector<const ReturnStmt *> Returns;
718 std::vector<const BreakStmt *> Breaks;
// LoopHandler: match callback that brackets a loop statement (bound as "s") with
// MC2_enterLoop() / MC2_exitLoop() calls and inserts MC2_exitLoop() before each return
// found in the loop body.
721 class LoopHandler : public MatchFinder::MatchCallback {
723 LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
725 virtual void run(const MatchFinder::MatchResult &Result) {
726 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
// Enter marker goes at the (macro-expansion) start location of the loop.
728 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729 "MC2_enterLoop();\n", true, true);
731 // annotate all returns with MC2_exitLoop()
732 // annotate all breaks that aren't further nested with MC2_exitLoop().
733 FindReturnsBreaksVisitor frbv;
// The loop body is traversed for ForStmt, WhileStmt or DoStmt; the conditions guarding
// the For and Do cases are not visible in this excerpt.
735 frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736 if (isa<WhileStmt>(s))
737 frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
739 frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
// NOTE(review): only the handling of returns is visible here; whether breaks are
// annotated (as the comment above states) cannot be confirmed from this excerpt.
741 for (auto & r : frbv.RetrieveReturns()) {
742 rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
745 // need to find all breaks and returns embedded inside the loop
// Exit marker goes after the token located one character past the loop's end location.
747 rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748 "\nMC2_exitLoop();\n");
755 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
// The match supplies a binary operator bound as "op" and/or a statement bound as
// "containingStmt"; the matcher itself is not visible here. When the right-hand side
// depends on MC state (or is a recorded malloc expression), an
// "<mcVar> = MC2_function_id(...)" line is inserted after the assignment and pending
// updates mentioning the left-hand variable are patched. Several lines of run(),
// including some branch conditions, are not visible in this excerpt.
756 class AssignHandler : public MatchFinder::MatchCallback {
758 AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759 std::set<const NamedDecl *> &DeclsInCond,
760 std::set<const VarDecl *> &DeclsNeedingMC,
761 std::map<const NamedDecl *, std::string> &DeclToMCVar,
762 std::set<const Stmt *> &StmtsHandled,
763 std::set<const Expr *> &MallocExprs,
764 std::vector<Update *> &DeferredUpdates) :
766 DeclsRead(DeclsRead),
767 DeclsInCond(DeclsInCond),
768 DeclsNeedingMC(DeclsNeedingMC),
769 DeclToMCVar(DeclToMCVar),
770 StmtsHandled(StmtsHandled),
771 MallocExprs(MallocExprs),
772 DeferredUpdates(DeferredUpdates) {}
774 virtual void run(const MatchFinder::MatchResult &Result) {
775 BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777 FindLocalsVisitor locals, locals_rhs;
779 const VarDecl * lhs = NULL;
780 const Expr * rhs = NULL;
// Declaration form: the assigned variable is the declared one and the RHS its initializer.
783 if (s && (ds = dyn_cast<DeclStmt>(s))) {
784 // long term goal: refactor the run() method to deal with one assignment at a time
785 // for now, if there are only declarations and no rhs's, we'll ignore this stmt
786 if (!ds->isSingleDecl()) {
787 for (auto & d : ds->decls()) {
// NOTE(review): the test below checks `d`, not the dyn_cast result `vd`. If a declaration
// in the group is not a VarDecl, vd is null and vd->hasInit() dereferences it -- confirm
// whether `!vd` was intended.
788 VarDecl * vd = dyn_cast<VarDecl>(d);
789 if (!d || vd->hasInit())
790 assert(0 && "unsupported form of decl");
795 lhs = retrieveSingleDecl(ds);
// Statements already processed by another handler (recorded in StmtsHandled) are tested
// for here; the action taken is not visible in this excerpt.
798 if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
802 if (lhs->hasInit()) {
803 rhs = lhs->getInit();
805 rhs = rhs->IgnoreCasts();
// Names of MC variables the RHS depends on; a set, so each name appears once.
811 std::set<std::string> mcState;
814 bool rhsRead = false;
816 bool lhsTooComplicated = false;
// Assignment form: the LHS must be a plain variable reference to be handled.
819 if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
820 lhs = dyn_cast<VarDecl>(vd->getDecl());
822 // kick the can along...
823 lhsTooComplicated = true;
828 rhs = rhs->IgnoreCasts();
831 // rhs must be MC-active state, i.e. in declsread
832 // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
835 locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
836 for (auto & nd : locals_rhs.RetrieveVars()) {
837 if (DeclsRead.find(nd) != DeclsRead.end())
842 locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
844 lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
// For each referenced declaration, prefer its recorded MC variable; otherwise, if it has
// been read, fall back to its encoded name.
846 for (auto & d : locals.RetrieveVars()) {
847 if (DeclToMCVar.count(d) > 0)
848 mcState.insert(DeclToMCVar[d]);
849 else if (DeclsRead.find(d) != DeclsRead.end())
850 mcState.insert(encode(d->getName().str()));
854 for (auto & d : locals_rhs.RetrieveVars()) {
855 if (DeclToMCVar.count(d) > 0)
856 mcState.insert(DeclToMCVar[d]);
857 else if (DeclsRead.find(d) != DeclsRead.end())
858 mcState.insert(encode(d->getName().str()));
// Emit only when the RHS depends on MC state or is a recorded malloc expression.
861 if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
862 if (lhsTooComplicated)
863 assert(0 && "couldn't find LHS of = operator");
865 std::stringstream nol;
866 std::string _lhsStr, lhsStr;
// Fresh MC variable for the assigned variable (the fnCount declaration is not visible
// here).
867 std::string mcVar = encodeFn(fnCount++);
869 lhsStr = lhs->getName().str();
870 _lhsStr = encode(lhsStr);
871 DeclToMCVar[lhs] = mcVar;
872 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
// A new function id is allocated only for non-malloc right-hand sides; the initial value
// of function_id is not visible in this excerpt.
875 if (!(MallocExprs.find(rhs) != MallocExprs.end()))
876 function_id = ++funcCount;
877 nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
879 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
881 nol << ", MC2_PTR_LENGTH";
882 for (auto & d : mcState) {
// Insertion point: derived from the end of the operator or of the containing statement
// (the selecting condition is not visible), then advanced by one more character below.
890 SourceLocation place;
892 place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
894 place = s->getLocEnd();
895 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
896 nol.str(), true, true);
// Patch any queued update text that still refers to lhs by a provisional name.
898 updateProvisionalName(DeferredUpdates, lhs, mcVar);
904 std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
905 std::set<const VarDecl *> &DeclsNeedingMC;
906 std::map<const NamedDecl *, std::string> &DeclToMCVar;
907 std::set<const Stmt *> &StmtsHandled;
908 std::set<const Expr *> &MallocExprs;
909 std::vector<Update *> &DeferredUpdates;
912 // record vars used in conditions
// In addition, hoists the condition of an if statement (bound as "if") into a fresh
// "_cond<N>" int variable declared before the if, emits a matching MCID variable named
// "_cond<N>_m", and replaces the condition text with the new variable. One path handles a
// binary operator bound as "bc"; the other handles the whole condition expression. The
// branch conditions selecting between them are not visible in this excerpt.
913 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
915 BranchConditionRefactoringHandler(Rewriter &rewrite,
916 std::set<const NamedDecl *> & DeclsInCond,
917 std::map<const NamedDecl *, std::string> &DeclToMCVar,
918 std::map<const Expr *, std::string> &ExprToMCVar,
919 std::map<const Expr *, SourceLocation> &Redirector,
920 std::vector<Update *> &DeferredUpdates) :
921 rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
922 ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
924 virtual void run(const MatchFinder::MatchResult &Result) {
925 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
926 Expr * cond = is->getCond();
928 // refactor out complicated conditions
// Despite its name, flv is a FindCallArgVisitor; it locates the first DeclRefExpr in the
// condition.
929 FindCallArgVisitor flv;
930 flv.TraverseStmt(cond);
933 BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
935 std::string condVar = encodeCond(condCount++);
936 std::stringstream condVarEncoded;
937 condVarEncoded << condVar << "_m";
939 // prettyprint the binary op
940 // e.g. int _cond0 = x == y;
942 llvm::raw_string_ostream S(SStr);
943 bc->printPretty(S, nullptr, rewrite.getLangOpts());
944 const std::string &Str = S.str();
946 std::stringstream prel;
948 bool is_equality = false;
949 // handle equality tests
// Special case: `a == b` where both sides are plain variable references is emitted as an
// MC2_equals(...) call that takes both MC variables and both values.
950 if (bc->getOpcode() == BO_EQ) {
951 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
952 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
953 DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
955 prel << "\nMCID " << condVarEncoded.str() << ";\n";
// NOTE(review): find(...)->second is dereferenced with no end() check visible in this
// excerpt; confirm both declarations are guaranteed to be in DeclToMCVar at this point.
956 std::string ld = DeclToMCVar.find(l->getDecl())->second,
957 rd = DeclToMCVar.find(r->getDecl())->second;
959 prel << "\nint " << condVar << " = MC2_equals(" <<
960 ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
961 rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
962 "&" << condVarEncoded.str() << ");\n";
// General binary-operator case: declare the int, then an MCID from MC2_function_id that
// depends on the first referenced variable's MC variable when one is recorded.
967 prel << "\nint " << condVar << " = " << Str << ";";
968 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
969 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
970 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
971 prel << ", " << DeclToMCVar[d->getDecl()];
// This path writes the prelude immediately, rather than deferring it.
976 ExprToMCVar[cond] = condVarEncoded.str();
977 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
978 prel.str(), false, true);
980 // rewrite the binary op with the newly-inserted var
981 Expr * RO = bc->getRHS(); // used for location only
// The replaced range runs from the start of the condition to the last character of the
// first token of the right operand.
983 int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
984 SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
985 rewrite.ReplaceText(SR, condVar);
// Second path: the whole condition expression is pretty-printed and hoisted.
987 std::string condVar = encodeCond(condCount++);
988 std::stringstream condVarEncoded;
989 condVarEncoded << condVar << "_m";
992 llvm::raw_string_ostream S(SStr);
993 cond->printPretty(S, nullptr, rewrite.getLangOpts());
994 const std::string &Str = S.str();
996 std::stringstream prel;
997 prel << "\nint " << condVar << " = " << Str << ";";
998 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
999 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
1000 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
1001 if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
1002 prel << ", " << DeclToMCVar[d->getDecl()];
// Otherwise a zero-length provisional name marks where the MC variable would be inserted.
1005 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
1010 ExprToMCVar[cond] = condVarEncoded.str();
1011 // gross hack; should look for any callexprs in cond
1012 // but right now, if it's a unaryop, just manually traverse
1013 if (isa<UnaryOperator>(cond)) {
1014 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
1015 ExprToMCVar[e] = condVarEncoded.str();
// This path defers the prelude so provisional names can still be patched.
1017 Update * u = new Update(is->getLocStart(), prel.str(), vars);
1018 DeferredUpdates.push_back(u);
1020 // rewrite the call op with the newly-inserted var
// Redirector records that annotations for `cond` belong at the start of the if statement;
// LoadHandler and RMWHandler consult this map.
1021 SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1022 Redirector[cond] = is->getLocStart();
1023 rewrite.ReplaceText(SR, condVar);
// Worklist over declarations: each named declaration reached is added to DeclsInCond, and
// a variable's initializer is traversed to find the next referenced declaration. The
// push/pop statements are not visible in this excerpt.
1026 std::deque<const Decl *> q;
1027 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1029 while (!q.empty()) {
1030 const Decl * d = q.back();
1032 if (isa<NamedDecl>(d))
1033 DeclsInCond.insert(cast<NamedDecl>(d));
1036 if ((vd = dyn_cast<VarDecl>(d))) {
1037 if (vd->hasInit()) {
1038 const Expr * e = vd->getInit();
1040 flv.TraverseStmt(const_cast<Expr *>(e));
1041 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1050 std::set<const NamedDecl *> & DeclsInCond;
1051 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1052 std::map<const Expr *, std::string> &ExprToMCVar;
1053 std::map<const Expr *, SourceLocation> &Redirector;
1054 std::vector<Update *> &DeferredUpdates;
// BranchAnnotationHandler: match callback that annotates both arms of an if statement
// (bound as "if") with MC2_branchUsesID(...) assignments to a fresh MCID variable, and
// with MC2_merge(...) calls. When the if has no else arm, an else arm containing the
// annotation is synthesised. Several lines of run(), including some branch conditions,
// are not visible in this excerpt, and the end of the class lies beyond the visible lines.
1057 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1059 BranchAnnotationHandler(Rewriter &rewrite,
1060 std::map<const NamedDecl *, std::string> & DeclToMCVar,
1061 std::map<const Expr *, std::string> & ExprToMCVar)
1063 DeclToMCVar(DeclToMCVar),
1064 ExprToMCVar(ExprToMCVar){}
1065 virtual void run(const MatchFinder::MatchResult &Result) {
1066 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1068 // if the branch condition is interesting:
1069 // (but right now, not too interesting)
1070 Expr * cond = is->getCond()->IgnoreCasts();
// Conditions that reference no declarations are left unannotated.
1072 FindLocalsVisitor flv;
1073 flv.TraverseStmt(cond);
1074 if (flv.RetrieveVars().size() == 0) return;
// Only the first referenced declaration is considered.
1076 const NamedDecl * condVar = flv.RetrieveVars()[0];
// MC variable for the condition: a per-expression entry wins, then a per-declaration
// entry, then the encoded declaration name.
1078 std::string mCondVar;
1079 if (ExprToMCVar.count(cond) > 0)
1080 mCondVar = ExprToMCVar[cond];
1081 else if (DeclToMCVar.count(condVar) > 0)
1082 mCondVar = DeclToMCVar[condVar];
1084 mCondVar = encode(condVar->getName());
1085 std::string brVar = encodeBranch(branchCount++);
// Declare the branch MCID variable immediately before the if statement.
1087 std::stringstream brline;
1088 brline << "MCID " << brVar << ";\n";
1089 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1090 brline.str(), false, true);
1092 Stmt * ts = is->getThen(), * es = is->getElse();
1093 bool tHasChild = hasChild(ts);
// tfl: where the then-arm annotation is inserted. The conditions selecting among the
// three assignments below are only partly visible in this excerpt.
1096 if (isa<CompoundStmt>(ts))
1097 tfl = getFirstChild(ts)->getLocStart();
1099 tfl = ts->getLocStart();
1101 tfl = ts->getLocStart().getLocWithOffset(1);
1102 SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1104 std::stringstream tlineStart, mergeStmt, eline;
1106 UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
// The then arm passes 1 and the else arm passes 0 as the second argument.
1107 tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1108 eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1110 mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1112 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
1115 int extra_else_offset = 0;
1117 if (tHasChild) { tls = getLastChild(ts); }
1118 if (tls) extra_else_offset = 2; else extra_else_offset = 1;
// The merge call is added to the then arm unless that arm's last statement is a return or
// a break.
1120 if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1121 extra_else_offset = 0;
1122 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1123 mergeStmt.str(), true, true);
// A single-statement then arm is wrapped in braces so the inserted lines stay inside it.
1125 if (tHasChild && !isa<CompoundStmt>(ts)) {
1126 rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1127 SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1128 rewrite.InsertText(tend, "}", true, true);
1130 if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
// Existing else arm: the same treatment as the then arm, using eline.
1133 SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1134 bool eHasChild = hasChild(es);
1136 if (eHasChild) els = getLastChild(es); else els = es;
1142 if (isa<CompoundStmt>(es))
1143 el = getFirstChild(es)->getLocStart();
1145 el = es->getLocStart();
1148 el = es->getLocStart().getLocWithOffset(1);
1149 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
1151 if (eHasChild && !isa<CompoundStmt>(es)) {
1152 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1153 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1156 if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1157 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
// No else arm: synthesise one containing the else annotation and the merge call, placed
// after the then arm's terminating semicolon, or after its last token when no semicolon
// follows.
1160 std::stringstream eCompoundLine;
1161 eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1162 SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1163 if (!tend.isValid())
1164 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1165 rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
// hasChild: a non-compound statement counts as having a child (itself); a compound
// statement has a child iff its body is non-empty.
1170 bool hasChild(Stmt * s) {
1171 if (!isa<CompoundStmt>(s)) return true;
1172 return (!cast<CompoundStmt>(s)->body_empty());
// getFirstChild: first statement of a non-empty compound statement (both preconditions
// are asserted).
1175 Stmt * getFirstChild(Stmt * s) {
1176 assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1177 assert(!cast<CompoundStmt>(s)->body_empty());
1178 return *(cast<CompoundStmt>(s)->body_begin());
// getLastChild: last statement of a non-empty compound statement. What is returned for a
// non-compound statement is not visible in this excerpt.
1181 Stmt * getLastChild(Stmt * s) {
1183 if ((cs = dyn_cast<CompoundStmt>(s))) {
1184 assert (!cs->body_empty());
1185 return cs->body_back();
1191 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1192 std::map<const Expr *, std::string> &ExprToMCVar;
// FunctionCallHandler: match callback for the call expression bound as
// "callExpr". From the code visible here it
//   (a) records the function passed as the second argument of thrd_create in
//       ThreadMains,
//   (b) for a call with a non-void return type and a bound "containingStmt",
//       declares a fresh MCID variable before that statement and builds an
//       extra trailing "&<var>" argument for the call, and
//   (c) looks up every argument that is a plain variable reference in
//       DeclToMCVar, defaulting to "MCID_NODEP".
// NOTE(review): this listing carries the original line number at the start of
// every line and omits some short lines (closing braces, access specifiers,
// continuation lines of a few calls). The comments describe only what is
// visible; anything marked "presumably" must be confirmed against the
// original file.
1195 class FunctionCallHandler : public MatchFinder::MatchCallback {
// Stores references to the rewriter and to the shared maps/sets; nothing is copied.
1197 FunctionCallHandler(Rewriter &rewrite,
1198 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1199 std::set<const FunctionDecl *> &ThreadMains)
1200 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
// Invoked once per match; "containingStmt" may be unbound, in which case s is null.
1202 virtual void run(const MatchFinder::MatchResult &Result) {
1203 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
// NOTE(review): d is passed to dyn_cast<> and nd is dereferenced below without
// a null check. If getCalleeDecl() can return null for a matched call (e.g. a
// call through a function-pointer expression -- TODO confirm), this crashes.
1204 Decl * d = ce->getCalleeDecl();
1205 NamedDecl * nd = dyn_cast<NamedDecl>(d);
1206 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1207 ASTContext *Context = Result.Context;
// thrd_create: take argument index 1, strip casts and an optional leading '&',
// and remember the referenced function as a thread entry point.
1209 if (nd->getName() == "thrd_create") {
// NOTE(review): getArg(1) is used without checking getNumArgs() first.
1210 Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1211 UnaryOperator * callee1;
1212 if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1213 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1214 callee0 = callee1->getSubExpr();
1216 DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1217 if (!callee) return;
// NOTE(review): fd is null when the referenced declaration is not a function
// (e.g. a function-pointer variable); that null pointer is inserted as-is.
1218 FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1219 ThreadMains.insert(fd);
// Non-void call with a known containing statement: create the MCID that will
// receive the callee's return-value dependency.
1226 if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1227 // TODO check that the type is mc-visible also?
1228 const DeclStmt * ds;
1229 const VarDecl * lhs = NULL;
1230 std::string mc_rv = encodeRV(rvCount++);
// Emit "MCID <mc_rv>;" at the expansion location of the containing statement's start.
1232 std::stringstream brline;
1233 brline << "MCID " << mc_rv << ";\n";
1234 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1235 brline.str(), false, true);
// Build the extra trailing argument "&<mc_rv>", preceded by ", " when the call
// already has arguments. The insertion point is the call's closing
// parenthesis; the text argument is on a line not shown here (presumably
// nol.str()).
1237 std::stringstream nol;
1238 if (ce->getNumArgs() > 0) nol << ", ";
1239 nol << "&" << mc_rv;
1240 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
// When the containing statement is a declaration, associate the declared
// variable with mc_rv in DeclToMCVar.
1243 if (s && (ds = dyn_cast<DeclStmt>(s))) {
1244 if (!ds->isSingleDecl()) {
// This loop variable d shadows the Decl * d declared at the top of run().
1245 for (auto & d : ds->decls()) {
1246 VarDecl * vd = dyn_cast<VarDecl>(d);
// NOTE(review): the test checks d, but the pointer that can be null here is
// vd; a non-VarDecl in the group would be dereferenced. Probably meant `!vd`.
1247 if (!d || vd->hasInit())
1248 assert(0 && "unsupported form of decl");
1253 lhs = retrieveSingleDecl(ds);
1256 DeclToMCVar[lhs] = mc_rv;
// For every argument pick the MCID to pass alongside it: the mapped MCID when
// the argument (casts stripped) is a reference to a declaration found in
// DeclToMCVar, otherwise "MCID_NODEP".
1259 for (const auto & a : ce->arguments()) {
1260 std::stringstream nol;
1262 std::string aa = "MCID_NODEP";
1264 Expr * e = a->IgnoreCasts();
// NOTE(review): the null check on dr is not visible in this listing
// (presumably on the omitted line before the dereference).
1265 DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1267 NamedDecl * d = dr->getDecl();
1268 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1269 aa = DeclToMCVar[d];
// Insert in front of the argument; the inserted text is on a line not shown here.
1274 if (a->getLocEnd().isValid())
1275 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
// Shared state, held by reference and owned elsewhere.
1282 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1283 std::set<const FunctionDecl *> &ThreadMains;
// ReturnHandler: match callback for the statement bound as "returnStmt" inside
// the function bound as "containingFunction". For functions that are not
// thread entry points it inserts "*retval = <mcid>;" in front of the return
// statement. <mcid> is the MCID mapped to the first local found in the
// returned expression, or "MCID_NODEP" when there is none.
// NOTE(review): this listing carries the original line number at the start of
// every line and omits some short lines (closing braces, access specifiers,
// possibly short statements); the comments describe only what is visible.
1286 class ReturnHandler : public MatchFinder::MatchCallback {
// Stores references to the rewriter and to the shared maps/sets; nothing is copied.
1288 ReturnHandler(Rewriter &rewrite,
1289 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1290 std::set<const FunctionDecl *> &ThreadMains)
1291 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1293 virtual void run(const MatchFinder::MatchResult &Result) {
1294 const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1295 ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
// NOTE(review): rv is null for a bare `return;`. No null check is visible in
// this listing, so such a return would still get a "*retval = ..." line --
// confirm whether an omitted line guards against this.
1296 Expr * rv = const_cast<Expr *>(rs->getRetValue());
// Thread entry points are left untouched.
1299 if (ThreadMains.find(fd) != ThreadMains.end()) return;
1300 // not sure why this is explicitly needed, but crashes without it
// Skips functions without a plain identifier, and user_main by name.
1301 if (!fd->getIdentifier() || fd->getName() == "user_main") return;
// Collect the variables referenced by the returned expression.
1303 FindLocalsVisitor flv;
1304 flv.TraverseStmt(rv);
1305 std::string mrv = "MCID_NODEP";
// Only the first collected variable is consulted.
1307 if (flv.RetrieveVars().size() > 0) {
1308 const NamedDecl * returnVar = flv.RetrieveVars()[0];
1309 if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1310 mrv = DeclToMCVar[returnVar];
// Emit the assignment at the expansion location of the return statement's
// start (InsertText is called with trailing arguments false, true).
1313 std::stringstream nol;
1314 nol << "*retval = " << mrv << ";\n";
1315 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1316 nol.str(), false, true);
// Shared state, held by reference and owned elsewhere.
1321 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1322 std::set<const FunctionDecl *> &ThreadMains;
1325 class VarDeclHandler : public MatchFinder::MatchCallback {
1327 VarDeclHandler(Rewriter &rewrite,
1328 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1329 std::set<const VarDecl *> &DeclsNeedingMC)
1330 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1332 virtual void run(const MatchFinder::MatchResult &Result) {
1333 VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1334 std::stringstream nol;
1336 if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1339 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1340 dn = DeclToMCVar[d];
1342 dn = encode(d->getName().str());
1344 nol << "MCID " << dn << "; ";
1346 if (d->getLocStart().isValid())
1347 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
1353 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1354 std::set<const VarDecl *> &DeclsNeedingMC;
// FunctionDeclHandler: match callback for the function declaration bound as
// "fd". For functions that are not thread entry points it prepares an
// "MCID <encoded name>, " text for each parameter and, for non-void
// functions, a trailing "MCID * retval" parameter text.
// NOTE(review): this listing carries the original line number at the start of
// every line and omits some short lines (closing braces, access specifiers,
// the continuation lines of both InsertText calls); the comments describe
// only what is visible.
1357 class FunctionDeclHandler : public MatchFinder::MatchCallback {
// Stores references to the rewriter and to the shared ThreadMains set.
1359 FunctionDeclHandler(Rewriter &rewrite,
1360 std::set<const FunctionDecl *> &ThreadMains)
1361 : rewrite(rewrite), ThreadMains(ThreadMains) {}
1363 virtual void run(const MatchFinder::MatchResult &Result) {
1364 FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
// Functions without a plain identifier are skipped.
1366 if (!fd->getIdentifier()) return;
// user_main is registered as a thread entry point and its signature is left alone.
1368 if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
// Already-known thread entry points are left alone as well.
1370 if (ThreadMains.find(fd) != ThreadMains.end()) return;
// Initial insertion point for the retval parameter: start of the function
// name + length of the name + 1 (presumably just past the opening
// parenthesis -- confirm for declarations with whitespace before '(').
1372 SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1373 for (auto & p : fd->params()) {
// Text for the MCID parameter that goes with p; inserted at p's start
// location (the remaining InsertText arguments are on a line not shown here).
1374 std::stringstream nol;
1375 nol << "MCID " << encode(p->getName()) << ", ";
1376 if (p->getLocStart().isValid())
1377 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
// Move the retval insertion point to p's end location plus the length of p's name.
1379 if (p->getLocEnd().isValid())
1380 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
// Non-void functions get an extra "MCID * retval" parameter, preceded by ", "
// when other parameters exist; it is inserted at LastParam (the remaining
// InsertText arguments are on a line not shown here).
1383 if (!fd->getReturnType()->isVoidType()) {
1384 std::stringstream nol;
1385 if (fd->param_size() > 0) nol << ", ";
1386 nol << "MCID * retval";
1387 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
// Shared state, held by reference and owned elsewhere.
1394 std::set<const FunctionDecl *> &ThreadMains;
1397 class BailHandler : public MatchFinder::MatchCallback {
1400 virtual void run(const MatchFinder::MatchResult &Result) {
1401 assert(0 && "we don't handle goto statements");
// MyASTConsumer: owns the shared analysis state (the sets and maps declared
// near the end of the class), one instance of every match callback, and five
// MatchFinder instances. The constructor registers AST matchers with their
// callbacks; HandleTranslationUnit runs the finders over the translation unit
// and then applies the deferred rewrites.
// NOTE(review): this listing carries the original line number at the start of
// every line and omits some short lines (closing braces, access specifiers,
// parts of the constructor initializer list, and the callback argument of
// several addMatcher calls). Where a callback is not visible the comments say
// so and do not guess.
1405 class MyASTConsumer : public ASTConsumer {
// Wires every handler to the rewriter R and to the shared state it needs.
1407 MyASTConsumer(Rewriter &R) : R(R),
1411 HandlerMalloc(MallocExprs),
1412 HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1413 HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1414 HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1416 HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1417 HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1418 HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1419 HandlerFunctionDecl(R, ThreadMains),
1420 HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1421 HandlerReturn(R, DeclToMCVar, ThreadMains),
1422 HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
// Calls that are a direct child of a compound statement, sit inside a variable
// declaration (whose parent statement is bound as "containingStmt"), or sit
// under an '=' operator (bound as "containingStmt") go to HandlerFunctionCall.
1424 MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1425 hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1426 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1427 &HandlerFunctionCall);
// Calls to malloc/calloc (the callback argument is on a line not shown here).
1428 MatcherLoadStore.addMatcher
1429 (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
// Calls to load_32/load_64 together with the statement that contains them
// (the end of this call, including its callback, is not shown here).
1432 MatcherLoadStore.addMatcher
1433 (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1434 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1435 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1436 hasParent(stmt().bind("containingStmt"))))
// Calls to store_32/store_64 (the callback argument is on a line not shown here).
1440 MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
// Calls to rmw_32/rmw_64; for the assignment form the left-hand variable
// reference is bound as "lhs" (the end of this call is not shown here).
1443 MatcherLoadStore.addMatcher
1444 (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1445 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1446 hasAncestor(binaryOperator(hasOperatorName("="),
1447 hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
// if statements: the trailing anything() lets the condition matcher accept any
// condition; the earlier alternatives bind "bc" (binary-operator condition) or
// "callExpr" (a load/store/rmw call inside the condition) when they apply.
1452 MatcherLoadStore.addMatcher(ifStmt(hasCondition
1453 (anyOf(binaryOperator().bind("bc"),
1454 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1455 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1456 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1457 anything()))).bind("if"),
1458 &HandlerBranchConditionRefactoring);
// for / while / do statements, each bound as "s" (callback arguments are on
// lines not shown here; presumably HandlerLoop -- confirm).
1460 MatcherLoadStore.addMatcher(forStmt().bind("s"),
1462 MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1464 MatcherLoadStore.addMatcher(doStmt().bind("s"),
// '=' operators that are inside a declaration statement or are a direct child
// of a compound statement, bound as "op" (callback argument not shown here).
1467 MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1468 hasParent(compoundStmt())),
1469 hasOperatorName("=")).bind("op"),
// Every declaration statement goes to HandlerAssign.
1471 MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
// Every if statement goes to HandlerAnnotateBranch.
1473 MatcherFunction.addMatcher(ifStmt().bind("if"),
1474 &HandlerAnnotateBranch);
// Declaration-level matchers: functions, variables, and return statements
// together with their enclosing function (the callback of the returnStmt
// matcher is on a line not shown here).
1476 MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1477 &HandlerFunctionDecl);
1478 MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1479 MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
// Any goto statement is routed to HandlerBail.
1482 MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1485 // Override the method that gets called for each parsed top-level
// Runs the five finders in the order written below and then applies the
// rewrites queued in DeferredUpdates.
1487 void HandleTranslationUnit(ASTContext &Context) override {
1488 LangOpts = Context.getLangOpts();
1490 MatcherFunctionCall.matchAST(Context);
1491 MatcherLoadStore.matchAST(Context);
1492 MatcherFunction.matchAST(Context);
1493 MatcherFunctionDecl.matchAST(Context);
1494 MatcherSanity.matchAST(Context);
// Each deferred update carries a location (loc) and the text to insert (update).
1496 for (auto & u : DeferredUpdates) {
1497 R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
// NOTE(review): the vector holds raw Update pointers and no delete is visible
// in this listing before the clear -- confirm who owns these objects.
1500 DeferredUpdates.clear();
1504 /* DeclsRead contains all local variables 'x' which:
1505 * 1) appear in 'x = load_32(...);
1506 * 2) appear in 'y = store_32(x); */
1507 std::set<const NamedDecl *> DeclsRead;
1508 /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
1509 std::set<const NamedDecl *> DeclsInCond;
// Names of the MCID variables associated with declarations and with expressions.
1510 std::map<const NamedDecl *, std::string> DeclToMCVar;
1511 std::map<const Expr *, std::string> ExprToMCVar;
// Variable declarations that VarDeclHandler prefixes with an MCID declaration.
1512 std::set<const VarDecl *> DeclsNeedingMC;
// Functions treated as thread entry points (thrd_create targets and user_main).
1513 std::set<const FunctionDecl *> ThreadMains;
1514 std::set<const Stmt *> StmtsHandled;
1515 std::set<const Expr *> MallocExprs;
1516 std::map<const Expr *, SourceLocation> Redirector;
// Rewrites queued by the handlers and applied at the end of HandleTranslationUnit.
1517 std::vector<Update *> DeferredUpdates;
// One instance of every match callback. Several of them take references to
// the state above in the constructor's initializer list.
1521 MallocHandler HandlerMalloc;
1522 LoadHandler HandlerLoad;
1523 StoreHandler HandlerStore;
1524 RMWHandler HandlerRMW;
1525 LoopHandler HandlerLoop;
1526 BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1527 BranchAnnotationHandler HandlerAnnotateBranch;
1528 AssignHandler HandlerAssign;
1529 FunctionDeclHandler HandlerFunctionDecl;
1530 FunctionCallHandler HandlerFunctionCall;
1531 ReturnHandler HandlerReturn;
1532 VarDeclHandler HandlerVarDecl;
1533 BailHandler HandlerBail;
// The finders are run separately, in a fixed order, by HandleTranslationUnit.
1534 MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1537 // For each source file provided to the tool, a new FrontendAction is created.
1538 class MyFrontendAction : public ASTFrontendAction {
1540 MyFrontendAction() {}
1541 void EndSourceFileAction() override {
1542 SourceManager &SM = TheRewriter.getSourceMgr();
1543 llvm::errs() << "** EndSourceFileAction for: "
1544 << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1546 // Now emit the rewritten buffer.
1547 TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1550 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1551 StringRef file) override {
1552 llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1553 TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1554 return llvm::make_unique<MyASTConsumer>(TheRewriter);
1558 Rewriter TheRewriter;
1561 int main(int argc, const char **argv) {
1562 CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1563 ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1565 return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());