1 // -*- indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
35 // Patrick Lam (prof.lam@gmail.com)
38 // Eli Bendersky (eliben@gmail.com)
40 //------------------------------------------------------------------------------
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
71 static std::string encode(std::string varName) {
73 nn << "_m" << varName;
78 static std::string encodeFn(int num) {
85 static std::string encodePtr(int num) {
92 static std::string encodeRMW(int num) {
98 static int branchCount;
99 static std::string encodeBranch(int num) {
100 std::stringstream nn;
105 static int condCount;
106 static std::string encodeCond(int num) {
107 std::stringstream nn;
108 nn << "_cond" << num;
113 static std::string encodeRV(int num) {
114 std::stringstream nn;
119 static int funcCount;
121 struct ProvisionalName {
123 const DeclRefExpr * pname;
126 ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127 ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
133 std::vector<ProvisionalName *> * pnames;
135 Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) :
136 loc(loc), update(update), pnames(pnames) {}
139 for (auto pname : *pnames) delete pname;
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145 for (Update * u : DeferredUpdates) {
146 for (int i = 0; i < u->pnames->size(); i++) {
147 ProvisionalName * v = (*(u->pnames))[i];
148 if (!v->enabled) continue;
149 if (now_known == v->pname->getDecl()) {
151 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
153 u->update.replace(v->index, v->length, mcVar);
154 for (int j = i+1; j < u->pnames->size(); j++) {
155 ProvisionalName * vv = (*(u->pnames))[j];
156 if (vv->index > v->index)
157 vv->index -= v->length - mcVar.length();
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165 // XXX iterate through all decls defined in s, not just the first one
166 assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167 if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168 return cast<VarDecl>(s->getSingleDecl());
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
174 FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
176 bool VisitStmt(Stmt * s) {
178 if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180 uo->getOpcode() == UnaryOperatorKind::UO_Deref)
185 if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
191 UnaryOp = NULL; DE = NULL;
194 const UnaryOperator * RetrieveUnaryOp() {
197 const Stmt * s = UnaryOp;
203 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204 s = op->getSubExpr();
205 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206 s = op->getSubExpr();
207 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
218 const DeclRefExpr * RetrieveDeclRefExpr() {
223 const UnaryOperator * UnaryOp;
224 const DeclRefExpr * DE;
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
229 FindLocalsVisitor() : Vars() {}
231 bool VisitDeclRefExpr(DeclRefExpr * de) {
232 Vars.push_back(de->getDecl());
240 const TinyPtrVector<const NamedDecl *> RetrieveVars() {
245 TinyPtrVector<const NamedDecl *> Vars;
249 class MallocHandler : public MatchFinder::MatchCallback {
251 MallocHandler(std::set<const Expr *> & MallocExprs) :
252 MallocExprs(MallocExprs) {}
254 virtual void run(const MatchFinder::MatchResult &Result) {
255 const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
257 MallocExprs.insert(ce);
261 std::set<const Expr *> &MallocExprs;
264 static void generateMC2Function(Rewriter & Rewrite,
269 const DeclRefExpr * lhs,
271 std::vector<ProvisionalName *> * vars1,
272 std::vector<Update *> & DeferredUpdates) {
273 // prettyprint the LHS (&newnode->value)
274 // e.g. int * _tmp0 = &newnode->value;
276 llvm::raw_string_ostream S(SStr);
277 e->printPretty(S, nullptr, Rewrite.getLangOpts());
278 const std::string &Str = S.str();
280 std::stringstream prel;
281 prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
283 // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284 prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
286 // XXX generate casts when they'd eliminate warnings
287 ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
290 prel << encode(lhsName) << "); ";
292 Update * u = new Update(loc, prel.str(), vars1);
293 DeferredUpdates.push_back(u);
296 class LoadHandler : public MatchFinder::MatchCallback {
298 LoadHandler(Rewriter &Rewrite,
299 std::set<const NamedDecl *> & DeclsRead,
300 std::set<const VarDecl *> & DeclsNeedingMC,
301 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302 std::map<const Expr *, std::string> &ExprToMCVar,
303 std::set<const Stmt *> & StmtsHandled,
304 std::map<const Expr *, SourceLocation> &Redirector,
305 std::vector<Update *> & DeferredUpdates) :
306 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307 ExprToMCVar(ExprToMCVar),
308 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
310 virtual void run(const MatchFinder::MatchResult &Result) {
311 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313 const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315 const Expr * lhs = NULL;
316 if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
319 const DeclRefExpr * rhs = NULL;
320 MemberExpr * ml = NULL;
321 bool isAddrOfR = false, isAddrMemberR = false;
323 StmtsHandled.insert(s);
325 std::string n, n_decl;
327 FindCallArgVisitor fcaVisitor;
329 fcaVisitor.TraverseStmt(ce->getArg(0));
330 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
335 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
337 FindLocalsVisitor flv;
339 flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340 for (auto & d : flv.RetrieveVars()) {
341 const VarDecl * dd = cast<VarDecl>(d);
343 // XXX todo rhs for non-decl stmts
344 if (!isa<ParmVarDecl>(dd))
345 DeclsNeedingMC.insert(dd);
347 DeclToMCVar[dd] = encode(n);
350 FindCallArgVisitor fcaVisitor;
352 fcaVisitor.TraverseStmt(ce->getArg(0));
353 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
358 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
362 DeclsNeedingMC.insert(d);
364 DeclToMCVar[d] = encode(n);
368 fcaVisitor.TraverseStmt(ce);
369 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371 DeclToMCVar[dd->getDecl()] = encode(n);
376 std::stringstream nol;
378 if (lhs && isa<DeclRefExpr>(lhs)) {
379 const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380 ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
387 nol << n_decl << encode(n) << "=";
388 nol << "MC2_nextOpLoadOffset(";
390 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
392 nol << encode(rhs->getNameInfo().getName().getAsString());
394 nol << ", MC2_OFFSET(";
395 nol << ml->getBase()->getType().getAsString();
397 nol << ml->getMemberDecl()->getName().str();
399 } else if (!isAddrOfR) {
401 nol << n_decl << encode(n) << "=";
402 nol << "MC2_nextOpLoad(";
403 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
405 nol << encode(rhs->getNameInfo().getName().getAsString());
408 nol << n_decl << encode(n) << "=";
409 nol << "MC2_nextOpLoad(";
414 nol << n_decl << encode(n) << "=";
415 nol << "MC2_nextOpLoad(";
423 SourceLocation ss = s->getLocStart();
425 // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426 // move over 1 so that we get the right location.
427 if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428 const Expr * e = dyn_cast<Expr>(s);
429 if (e && Redirector.count(e) > 0)
431 Update * u = new Update(ss, nol.str(), vars);
432 DeferredUpdates.insert(DeferredUpdates.begin(), u);
437 std::set<const NamedDecl *> & DeclsRead;
438 std::set<const VarDecl *> & DeclsNeedingMC;
439 std::map<const Expr *, std::string> &ExprToMCVar;
440 std::map<const NamedDecl *, std::string> &DeclToMCVar;
441 std::set<const Stmt *> &StmtsHandled;
442 std::vector<Update *> &DeferredUpdates;
443 std::map<const Expr *, SourceLocation> &Redirector;
446 class StoreHandler : public MatchFinder::MatchCallback {
448 StoreHandler(Rewriter &Rewrite,
449 std::set<const NamedDecl *> & DeclsRead,
450 std::set<const VarDecl *> &DeclsNeedingMC,
451 std::vector<Update *> & DeferredUpdates) :
452 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
454 virtual void run(const MatchFinder::MatchResult &Result) {
455 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
459 fcaVisitor.TraverseStmt(ce->getArg(0));
460 const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461 const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
463 std::stringstream nol;
468 if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469 isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470 isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
475 nol << "MC2_nextOpStore(";
479 MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
481 nol << "MC2_nextOpStoreOffset(";
483 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
486 nol << encode(lhs->getNameInfo().getName().getAsString());
487 if (!isa<ParmVarDecl>(lhs->getDecl()))
488 DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
490 nol << ", MC2_OFFSET(";
491 nol << ml->getBase()->getType().getAsString();
493 nol << ml->getMemberDecl()->getName().str();
496 nol << "MC2_nextOpStore(";
497 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
500 nol << encode(lhs->getNameInfo().getName().getAsString());
505 nol << "MC2_nextOpStore(";
512 fcaVisitor.TraverseStmt(ce->getArg(1));
513 const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
516 bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517 bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
519 if (rhs && !isAddrOfR) {
520 assert (!isDerefR && "Must use atomic load for dereferences!");
521 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
524 nol << encode(rhs->getNameInfo().getName().getAsString());
525 DeclsRead.insert(rhs->getDecl());
531 Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532 DeferredUpdates.push_back(u);
537 FindCallArgVisitor fcaVisitor;
538 std::set<const NamedDecl *> & DeclsRead;
539 std::set<const VarDecl *> & DeclsNeedingMC;
540 std::vector<Update *> &DeferredUpdates;
543 class RMWHandler : public MatchFinder::MatchCallback {
545 RMWHandler(Rewriter &rewrite,
546 std::set<const NamedDecl *> & DeclsRead,
547 std::set<const NamedDecl *> & DeclsInCond,
548 std::map<const NamedDecl *, std::string> &DeclToMCVar,
549 std::map<const Expr *, std::string> &ExprToMCVar,
550 std::set<const Stmt *> & StmtsHandled,
551 std::map<const Expr *, SourceLocation> &Redirector,
552 std::vector<Update *> & DeferredUpdates) :
553 rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
554 ExprToMCVar(ExprToMCVar),
555 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
557 virtual void run(const MatchFinder::MatchResult &Result) {
558 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
562 std::stringstream nol;
564 std::string rmwMCVar;
565 rmwMCVar = encodeRMW(rmwCount++);
567 const VarDecl * rmw_lhs;
569 StmtsHandled.insert(s);
570 assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
572 if ((ds = dyn_cast<DeclStmt>(s))) {
573 rmw_lhs = retrieveSingleDecl(ds);
575 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576 assert (isa<DeclRefExpr>(e));
577 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
579 DeclToMCVar[rmw_lhs] = rmwMCVar;
582 // retrieve effective LHS of the RMW
584 fcaVisitor.TraverseStmt(ce->getArg(1));
585 const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586 const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587 bool isAddrMemberL = false;
589 if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590 isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
593 nol << "MCID " << rmwMCVar;
595 MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
597 nol << " = MC2_nextRMWOffset(";
599 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
602 nol << encode(elhs->getNameInfo().getName().getAsString());
604 nol << ", MC2_OFFSET(";
605 nol << ml->getBase()->getType().getAsString();
607 nol << ml->getMemberDecl()->getName().str();
610 nol << " = MC2_nextRMW(";
611 bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
617 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
620 std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
629 // handle both RHS ops
631 for (int arg = 2; arg < 4; arg++) {
633 fcaVisitor.TraverseStmt(ce->getArg(arg));
634 const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635 const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
637 bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638 bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
640 if (a && !isAddrOfR) {
641 assert (!isDerefR && "Must use atomic load for dereferences!");
643 DeclsInCond.insert(a->getDecl());
645 if (outputted > 0) nol << ", ";
648 bool alreadyMCVar = false;
649 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
651 nol << DeclToMCVar[a->getDecl()];
654 std::string an = "MCID_NODEP";
655 ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
660 DeclsRead.insert(a->getDecl());
663 if (outputted > 0) nol << ", ";
671 SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672 const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673 if (e && Redirector.count(e) > 0)
674 place = Redirector[e];
675 Update * u = new Update(place, nol.str(), vars);
676 DeferredUpdates.insert(DeferredUpdates.begin(), u);
681 FindCallArgVisitor fcaVisitor;
682 std::set<const NamedDecl *> &DeclsRead;
683 std::set<const NamedDecl *> &DeclsInCond;
684 std::map<const NamedDecl *, std::string> &DeclToMCVar;
685 std::map<const Expr *, std::string> &ExprToMCVar;
686 std::set<const Stmt *> &StmtsHandled;
687 std::vector<Update *> &DeferredUpdates;
688 std::map<const Expr *, SourceLocation> &Redirector;
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
693 FindReturnsBreaksVisitor() : Returns(), Breaks() {}
695 bool VisitStmt(Stmt * s) {
696 if (isa<ReturnStmt>(s))
697 Returns.push_back(cast<ReturnStmt>(s));
699 if (isa<BreakStmt>(s))
700 Breaks.push_back(cast<BreakStmt>(s));
705 Returns.clear(); Breaks.clear();
708 const std::vector<const ReturnStmt *> RetrieveReturns() {
712 const std::vector<const BreakStmt *> RetrieveBreaks() {
717 std::vector<const ReturnStmt *> Returns;
718 std::vector<const BreakStmt *> Breaks;
721 class LoopHandler : public MatchFinder::MatchCallback {
723 LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
725 virtual void run(const MatchFinder::MatchResult &Result) {
726 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
728 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729 "MC2_enterLoop();\n", true, true);
731 // annotate all returns with MC2_exitLoop()
732 // annotate all breaks that aren't further nested with MC2_exitLoop().
733 FindReturnsBreaksVisitor frbv;
735 frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736 if (isa<WhileStmt>(s))
737 frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
739 frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
741 for (auto & r : frbv.RetrieveReturns()) {
742 rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
745 // need to find all breaks and returns embedded inside the loop
747 rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748 "\nMC2_exitLoop();\n");
755 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
756 class AssignHandler : public MatchFinder::MatchCallback {
758 AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759 std::set<const NamedDecl *> &DeclsInCond,
760 std::set<const VarDecl *> &DeclsNeedingMC,
761 std::map<const NamedDecl *, std::string> &DeclToMCVar,
762 std::set<const Stmt *> &StmtsHandled,
763 std::set<const Expr *> &MallocExprs,
764 std::vector<Update *> &DeferredUpdates) :
766 DeclsRead(DeclsRead),
767 DeclsInCond(DeclsInCond),
768 DeclsNeedingMC(DeclsNeedingMC),
769 DeclToMCVar(DeclToMCVar),
770 StmtsHandled(StmtsHandled),
771 MallocExprs(MallocExprs),
772 DeferredUpdates(DeferredUpdates) {}
774 virtual void run(const MatchFinder::MatchResult &Result) {
775 BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777 FindLocalsVisitor flv;
779 const VarDecl * lhs = NULL;
780 const Expr * rhs = NULL;
783 if (s && (ds = dyn_cast<DeclStmt>(s))) {
784 // long term goal: refactor the run() method to deal with one assignment at a time
785 // for now, if there is only declarations and no rhs's, we'll ignore this stmt
786 if (!ds->isSingleDecl()) {
787 for (auto & d : ds->decls()) {
788 VarDecl * vd = dyn_cast<VarDecl>(d);
789 if (!d || vd->hasInit())
790 assert(0 && "unsupported form of decl");
795 lhs = retrieveSingleDecl(ds);
798 if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
802 if (lhs->hasInit()) {
803 rhs = lhs->getInit();
805 rhs = rhs->IgnoreCasts();
811 std::set<std::string> mcState;
813 bool lhsTooComplicated = false;
815 flv.TraverseStmt(op);
818 if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
819 lhs = dyn_cast<VarDecl>(vd->getDecl());
821 // kick the can along...
822 lhsTooComplicated = true;
827 rhs = rhs->IgnoreCasts();
830 // rhs must be MC-active state, i.e. in declsread
831 // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
832 flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
835 if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
836 for (auto & d : flv.RetrieveVars()) {
837 if (DeclToMCVar.count(d) > 0)
838 mcState.insert(DeclToMCVar[d]);
839 else if (DeclsRead.find(d) != DeclsRead.end())
840 mcState.insert(encode(d->getName().str()));
844 if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
845 if (lhsTooComplicated)
846 assert(0 && "couldn't find LHS of = operator");
848 std::stringstream nol;
849 std::string _lhsStr, lhsStr;
850 std::string mcVar = encodeFn(fnCount++);
852 lhsStr = lhs->getName().str();
853 _lhsStr = encode(lhsStr);
854 DeclToMCVar[lhs] = mcVar;
855 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
858 if (!(MallocExprs.find(rhs) != MallocExprs.end()))
859 function_id = ++funcCount;
860 nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
862 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
864 nol << ", MC2_PTR_LENGTH";
865 for (auto & d : mcState) {
873 SourceLocation place;
875 place = op->getLocEnd().getLocWithOffset(1);
877 place = s->getLocEnd();
878 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
879 nol.str(), true, true);
881 updateProvisionalName(DeferredUpdates, lhs, mcVar);
887 std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
888 std::set<const VarDecl *> &DeclsNeedingMC;
889 std::map<const NamedDecl *, std::string> &DeclToMCVar;
890 std::set<const Stmt *> &StmtsHandled;
891 std::set<const Expr *> &MallocExprs;
892 std::vector<Update *> &DeferredUpdates;
895 // record vars used in conditions
896 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
898 BranchConditionRefactoringHandler(Rewriter &rewrite,
899 std::set<const NamedDecl *> & DeclsInCond,
900 std::map<const NamedDecl *, std::string> &DeclToMCVar,
901 std::map<const Expr *, std::string> &ExprToMCVar,
902 std::map<const Expr *, SourceLocation> &Redirector,
903 std::vector<Update *> &DeferredUpdates) :
904 rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
905 ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
907 virtual void run(const MatchFinder::MatchResult &Result) {
908 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
909 Expr * cond = is->getCond();
911 // refactor out complicated conditions
912 FindCallArgVisitor flv;
913 flv.TraverseStmt(cond);
916 BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
918 std::string condVar = encodeCond(condCount++);
919 std::stringstream condVarEncoded;
920 condVarEncoded << condVar << "_m";
922 // prettyprint the binary op
923 // e.g. int _cond0 = x == y;
925 llvm::raw_string_ostream S(SStr);
926 bc->printPretty(S, nullptr, rewrite.getLangOpts());
927 const std::string &Str = S.str();
929 std::stringstream prel;
931 bool is_equality = false;
932 // handle equality tests
933 if (bc->getOpcode() == BO_EQ) {
934 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
935 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
936 DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
938 prel << "\nMCID " << condVarEncoded.str() << ";\n";
939 std::string ld = DeclToMCVar.find(l->getDecl())->second,
940 rd = DeclToMCVar.find(r->getDecl())->second;
942 prel << "\nint " << condVar << " = MC2_equals(" <<
943 ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
944 rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
945 "&" << condVarEncoded.str() << ");\n";
950 prel << "\nint " << condVar << " = " << Str << ";";
951 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
952 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
953 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
954 prel << ", " << DeclToMCVar[d->getDecl()];
959 ExprToMCVar[cond] = condVarEncoded.str();
960 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
961 prel.str(), false, true);
963 // rewrite the binary op with the newly-inserted var
964 Expr * RO = bc->getRHS(); // used for location only
966 int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
967 SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
968 rewrite.ReplaceText(SR, condVar);
970 std::string condVar = encodeCond(condCount++);
971 std::stringstream condVarEncoded;
972 condVarEncoded << condVar << "_m";
975 llvm::raw_string_ostream S(SStr);
976 cond->printPretty(S, nullptr, rewrite.getLangOpts());
977 const std::string &Str = S.str();
979 std::stringstream prel;
980 prel << "\nint " << condVar << " = " << Str << ";";
981 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
982 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
983 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
984 if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
985 prel << ", " << DeclToMCVar[d->getDecl()];
988 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
993 ExprToMCVar[cond] = condVarEncoded.str();
994 // gross hack; should look for any callexprs in cond
995 // but right now, if it's a unaryop, just manually traverse
996 if (isa<UnaryOperator>(cond)) {
997 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
998 ExprToMCVar[e] = condVarEncoded.str();
1000 Update * u = new Update(is->getLocStart(), prel.str(), vars);
1001 DeferredUpdates.push_back(u);
1003 // rewrite the call op with the newly-inserted var
1004 SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1005 Redirector[cond] = is->getLocStart();
1006 rewrite.ReplaceText(SR, condVar);
1009 std::deque<const Decl *> q;
1010 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1012 while (!q.empty()) {
1013 const Decl * d = q.back();
1015 if (isa<NamedDecl>(d))
1016 DeclsInCond.insert(cast<NamedDecl>(d));
1019 if ((vd = dyn_cast<VarDecl>(d))) {
1020 if (vd->hasInit()) {
1021 const Expr * e = vd->getInit();
1023 flv.TraverseStmt(const_cast<Expr *>(e));
1024 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1033 std::set<const NamedDecl *> & DeclsInCond;
1034 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1035 std::map<const Expr *, std::string> &ExprToMCVar;
1036 std::map<const Expr *, SourceLocation> &Redirector;
1037 std::vector<Update *> &DeferredUpdates;
1040 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1042 BranchAnnotationHandler(Rewriter &rewrite,
1043 std::map<const NamedDecl *, std::string> & DeclToMCVar,
1044 std::map<const Expr *, std::string> & ExprToMCVar)
1046 DeclToMCVar(DeclToMCVar),
1047 ExprToMCVar(ExprToMCVar){}
1048 virtual void run(const MatchFinder::MatchResult &Result) {
1049 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1051 // if the branch condition is interesting:
1052 // (but right now, not too interesting)
1053 Expr * cond = is->getCond()->IgnoreCasts();
1055 FindLocalsVisitor flv;
1056 flv.TraverseStmt(cond);
1057 if (flv.RetrieveVars().size() == 0) return;
1059 const NamedDecl * condVar = flv.RetrieveVars()[0];
1061 std::string mCondVar;
1062 if (ExprToMCVar.count(cond) > 0)
1063 mCondVar = ExprToMCVar[cond];
1064 else if (DeclToMCVar.count(condVar) > 0)
1065 mCondVar = DeclToMCVar[condVar];
1067 mCondVar = encode(condVar->getName());
1068 std::string brVar = encodeBranch(branchCount++);
1070 std::stringstream brline;
1071 brline << "MCID " << brVar << ";\n";
1072 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1073 brline.str(), false, true);
1075 Stmt * ts = is->getThen(), * es = is->getElse();
1076 bool tHasChild = hasChild(ts);
1079 if (isa<CompoundStmt>(ts))
1080 tfl = getFirstChild(ts)->getLocStart();
1082 tfl = ts->getLocStart();
1084 tfl = ts->getLocStart().getLocWithOffset(1);
1085 SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1087 std::stringstream tlineStart, mergeStmt, eline;
1089 UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1090 tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1091 eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1093 mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1095 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
1098 int extra_else_offset = 0;
1100 if (tHasChild) { tls = getLastChild(ts); }
1101 if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1103 if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1104 extra_else_offset = 0;
1105 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1106 mergeStmt.str(), true, true);
1108 if (tHasChild && !isa<CompoundStmt>(ts)) {
1109 rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1110 SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1111 rewrite.InsertText(tend, "}", true, true);
1113 if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1116 SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1117 bool eHasChild = hasChild(es);
1119 if (eHasChild) els = getLastChild(es); else els = es;
1125 if (isa<CompoundStmt>(es))
1126 el = getFirstChild(es)->getLocStart();
1128 el = es->getLocStart();
1131 el = es->getLocStart().getLocWithOffset(1);
1132 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
1134 if (eHasChild && !isa<CompoundStmt>(es)) {
1135 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1136 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1139 if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1140 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
1143 std::stringstream eCompoundLine;
1144 eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1145 SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1146 if (!tend.isValid())
1147 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1148 rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
1153 bool hasChild(Stmt * s) {
1154 if (!isa<CompoundStmt>(s)) return true;
1155 return (!cast<CompoundStmt>(s)->body_empty());
1158 Stmt * getFirstChild(Stmt * s) {
1159 assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1160 assert(!cast<CompoundStmt>(s)->body_empty());
1161 return *(cast<CompoundStmt>(s)->body_begin());
1164 Stmt * getLastChild(Stmt * s) {
1166 if ((cs = dyn_cast<CompoundStmt>(s))) {
1167 assert (!cs->body_empty());
1168 return cs->body_back();
1174 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1175 std::map<const Expr *, std::string> &ExprToMCVar;
1178 class FunctionCallHandler : public MatchFinder::MatchCallback {
1180 FunctionCallHandler(Rewriter &rewrite,
1181 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1182 std::set<const FunctionDecl *> &ThreadMains)
1183 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1185 virtual void run(const MatchFinder::MatchResult &Result) {
1186 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
1187 Decl * d = ce->getCalleeDecl();
1188 NamedDecl * nd = dyn_cast<NamedDecl>(d);
1189 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1190 ASTContext *Context = Result.Context;
1192 if (nd->getName() == "thrd_create") {
1193 Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1194 UnaryOperator * callee1;
1195 if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1196 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1197 callee0 = callee1->getSubExpr();
1199 DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1200 if (!callee) return;
1201 FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1202 ThreadMains.insert(fd);
1209 if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1210 // TODO check that the type is mc-visible also?
1211 const DeclStmt * ds;
1212 const VarDecl * lhs = NULL;
1213 std::string mc_rv = encodeRV(rvCount++);
1215 std::stringstream brline;
1216 brline << "MCID " << mc_rv << ";\n";
1217 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1218 brline.str(), false, true);
1220 std::stringstream nol;
1221 if (ce->getNumArgs() > 0) nol << ", ";
1222 nol << "&" << mc_rv;
1223 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
1226 if (s && (ds = dyn_cast<DeclStmt>(s))) {
1227 if (!ds->isSingleDecl()) {
1228 for (auto & d : ds->decls()) {
1229 VarDecl * vd = dyn_cast<VarDecl>(d);
1230 if (!d || vd->hasInit())
1231 assert(0 && "unsupported form of decl");
1236 lhs = retrieveSingleDecl(ds);
1239 DeclToMCVar[lhs] = mc_rv;
1242 for (const auto & a : ce->arguments()) {
1243 std::stringstream nol;
1245 std::string aa = "MCID_NODEP";
1247 Expr * e = a->IgnoreCasts();
1248 DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1250 NamedDecl * d = dr->getDecl();
1251 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1252 aa = DeclToMCVar[d];
1257 if (a->getLocEnd().isValid())
1258 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
1265 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1266 std::set<const FunctionDecl *> &ThreadMains;
1269 class ReturnHandler : public MatchFinder::MatchCallback {
1271 ReturnHandler(Rewriter &rewrite,
1272 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1273 std::set<const FunctionDecl *> &ThreadMains)
1274 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1276 virtual void run(const MatchFinder::MatchResult &Result) {
1277 const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1278 ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1279 Expr * rv = const_cast<Expr *>(rs->getRetValue());
1282 if (ThreadMains.find(fd) != ThreadMains.end()) return;
1283 // not sure why this is explicitly needed, but crashes without it
1284 if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1286 FindLocalsVisitor flv;
1287 flv.TraverseStmt(rv);
1288 std::string mrv = "MCID_NODEP";
1290 if (flv.RetrieveVars().size() > 0) {
1291 const NamedDecl * returnVar = flv.RetrieveVars()[0];
1292 if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1293 mrv = DeclToMCVar[returnVar];
1296 std::stringstream nol;
1297 nol << "*retval = " << mrv << ";\n";
1298 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1299 nol.str(), false, true);
1304 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1305 std::set<const FunctionDecl *> &ThreadMains;
1308 class VarDeclHandler : public MatchFinder::MatchCallback {
1310 VarDeclHandler(Rewriter &rewrite,
1311 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1312 std::set<const VarDecl *> &DeclsNeedingMC)
1313 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1315 virtual void run(const MatchFinder::MatchResult &Result) {
1316 VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1317 std::stringstream nol;
1319 if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1322 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1323 dn = DeclToMCVar[d];
1325 dn = encode(d->getName().str());
1327 nol << "MCID " << dn << "; ";
1329 if (d->getLocStart().isValid())
1330 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
1336 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1337 std::set<const VarDecl *> &DeclsNeedingMC;
1340 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1342 FunctionDeclHandler(Rewriter &rewrite,
1343 std::set<const FunctionDecl *> &ThreadMains)
1344 : rewrite(rewrite), ThreadMains(ThreadMains) {}
1346 virtual void run(const MatchFinder::MatchResult &Result) {
1347 FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
1349 if (!fd->getIdentifier()) return;
1351 if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1353 if (ThreadMains.find(fd) != ThreadMains.end()) return;
1355 SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1356 for (auto & p : fd->params()) {
1357 std::stringstream nol;
1358 nol << "MCID " << encode(p->getName()) << ", ";
1359 if (p->getLocStart().isValid())
1360 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
1362 if (p->getLocEnd().isValid())
1363 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1366 if (!fd->getReturnType()->isVoidType()) {
1367 std::stringstream nol;
1368 if (fd->param_size() > 0) nol << ", ";
1369 nol << "MCID * retval";
1370 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
1377 std::set<const FunctionDecl *> &ThreadMains;
1380 class BailHandler : public MatchFinder::MatchCallback {
1383 virtual void run(const MatchFinder::MatchResult &Result) {
1384 assert(0 && "we don't handle goto statements");
1388 class MyASTConsumer : public ASTConsumer {
1390 MyASTConsumer(Rewriter &R) : R(R),
1394 HandlerMalloc(MallocExprs),
1395 HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1396 HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1397 HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1399 HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1400 HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1401 HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1402 HandlerFunctionDecl(R, ThreadMains),
1403 HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1404 HandlerReturn(R, DeclToMCVar, ThreadMains),
1405 HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
1407 MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1408 hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1409 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1410 &HandlerFunctionCall);
1411 MatcherLoadStore.addMatcher
1412 (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1415 MatcherLoadStore.addMatcher
1416 (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1417 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1418 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1419 hasParent(stmt().bind("containingStmt"))))
1423 MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1426 MatcherLoadStore.addMatcher
1427 (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1428 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1429 hasAncestor(binaryOperator(hasOperatorName("="),
1430 hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1435 MatcherLoadStore.addMatcher(ifStmt(hasCondition
1436 (anyOf(binaryOperator().bind("bc"),
1437 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1438 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1439 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1440 anything()))).bind("if"),
1441 &HandlerBranchConditionRefactoring);
1443 MatcherLoadStore.addMatcher(forStmt().bind("s"),
1445 MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1447 MatcherLoadStore.addMatcher(doStmt().bind("s"),
1450 MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1451 hasParent(compoundStmt())),
1452 hasOperatorName("=")).bind("op"),
1454 MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1456 MatcherFunction.addMatcher(ifStmt().bind("if"),
1457 &HandlerAnnotateBranch);
1459 MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1460 &HandlerFunctionDecl);
1461 MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1462 MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
1465 MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1468 // Override the method that gets called for each parsed top-level
1470 void HandleTranslationUnit(ASTContext &Context) override {
1471 LangOpts = Context.getLangOpts();
1473 MatcherFunctionCall.matchAST(Context);
1474 MatcherLoadStore.matchAST(Context);
1475 MatcherFunction.matchAST(Context);
1476 MatcherFunctionDecl.matchAST(Context);
1477 MatcherSanity.matchAST(Context);
1479 for (auto & u : DeferredUpdates) {
1480 R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
1483 DeferredUpdates.clear();
1487 /* DeclsRead contains all local variables 'x' which:
1488 * 1) appear in 'x = load_32(...);
1489 * 2) appear in 'y = store_32(x); */
1490 std::set<const NamedDecl *> DeclsRead, DeclsInCond;
1491 std::map<const NamedDecl *, std::string> DeclToMCVar;
1492 std::map<const Expr *, std::string> ExprToMCVar;
1493 std::set<const VarDecl *> DeclsNeedingMC;
1494 std::set<const FunctionDecl *> ThreadMains;
1495 std::set<const Stmt *> StmtsHandled;
1496 std::set<const Expr *> MallocExprs;
1497 std::map<const Expr *, SourceLocation> Redirector;
1498 std::vector<Update *> DeferredUpdates;
1502 MallocHandler HandlerMalloc;
1503 LoadHandler HandlerLoad;
1504 StoreHandler HandlerStore;
1505 RMWHandler HandlerRMW;
1506 LoopHandler HandlerLoop;
1507 BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1508 BranchAnnotationHandler HandlerAnnotateBranch;
1509 AssignHandler HandlerAssign;
1510 FunctionDeclHandler HandlerFunctionDecl;
1511 FunctionCallHandler HandlerFunctionCall;
1512 ReturnHandler HandlerReturn;
1513 VarDeclHandler HandlerVarDecl;
1514 BailHandler HandlerBail;
1515 MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1518 // For each source file provided to the tool, a new FrontendAction is created.
1519 class MyFrontendAction : public ASTFrontendAction {
1521 MyFrontendAction() {}
1522 void EndSourceFileAction() override {
1523 SourceManager &SM = TheRewriter.getSourceMgr();
1524 llvm::errs() << "** EndSourceFileAction for: "
1525 << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1527 // Now emit the rewritten buffer.
1528 TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1531 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1532 StringRef file) override {
1533 llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1534 TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1535 return llvm::make_unique<MyASTConsumer>(TheRewriter);
1539 Rewriter TheRewriter;
1542 int main(int argc, const char **argv) {
1543 CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1544 ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1546 return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());