1 // -*- indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
35 // Patrick Lam (prof.lam@gmail.com)
38 // Eli Bendersky (eliben@gmail.com)
40 //------------------------------------------------------------------------------
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
// File-scope globals. LangOpts is a default-constructed clang::LangOptions;
// AddMC2AnnotationsCategory is the llvm::cl option category titled
// "Add MC2 Annotations" (its uses are not visible in this listing).
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
// encode: builds the model-checker (MCID) variable name for a source
// variable by writing "_m" followed by varName.
// NOTE(review): the declaration of `nn` and the return statement are on
// lines elided from this listing; `nn` is presumably a std::stringstream
// as in the sibling encoders, giving e.g. "x" -> "_mx".
71 static std::string encode(std::string varName) {
73 nn << "_m" << varName;
// encodeFn: name generator for the MCID variables that receive
// MC2_function_id results; the caller in this file passes a
// post-incremented counter. Body elided from this listing.
78 static std::string encodeFn(int num) {
// encodePtr: name generator taking an integer index. Its body is elided
// from this listing and no caller is visible here.
85 static std::string encodePtr(int num) {
// encodeRMW: name generator for the MCID variable that holds the result of
// an emitted MC2_nextRMW / MC2_nextRMWOffset call (the RMW handler calls it
// with rmwCount++). Body elided from this listing.
92 static std::string encodeRMW(int num) {
// branchCount: running counter that makes branch MCID names unique.
98 static int branchCount;
// encodeBranch: builds the MCID name for an instrumented if-statement's
// branch id. The text it writes is on lines elided from this listing.
99 static std::string encodeBranch(int num) {
100 std::stringstream nn;
// condCount: running counter that makes condition-temporary names unique.
105 static int condCount;
// encodeCond: builds the name of the int temporary that holds a hoisted
// branch condition: "_cond" followed by num (e.g. 0 -> "_cond0").
106 static std::string encodeCond(int num) {
107 std::stringstream nn;
108 nn << "_cond" << num;
// encodeRV: name generator taking an integer index. The text it writes and
// its return are on elided lines, and no caller is visible here.
113 static std::string encodeRV(int num) {
114 std::stringstream nn;
// funcCount: global counter that is pre-incremented (++funcCount) each time
// an MC2_function_id call is emitted, so every generated call gets its own id.
119 static int funcCount;
// ProvisionalName: a placeholder inside a generated text fragment (an
// Update's string).
//   - index: character offset where the placeholder starts.
//   - length: number of characters it occupies.
//   - pname: the variable reference it stands for.
// updateProvisionalName() later overwrites the placeholder with the
// variable's real MCID name once that name is known. Placeholders whose
// `enabled` flag is false are skipped.
121 struct ProvisionalName {
123 const DeclRefExpr * pname;
// Default length is the length of encode(<referenced name>), i.e. the
// "_m<name>" text that callers emit right after recording the placeholder.
126 ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
// Explicit-length form, used for a "MCID_NODEP"-sized or zero-length
// (pure insertion point) placeholder.
127 ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
// Members, constructor and cleanup of the deferred-update record (the struct
// header is elided from this listing). It holds:
//   - update: a generated text fragment, presumably to be inserted at `loc`.
//   - pnames: the ProvisionalName placeholders that live inside that text.
133 std::vector<ProvisionalName *> * pnames;
135 Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) :
136 loc(loc), update(update), pnames(pnames) {}
// Frees the owned placeholders.
// NOTE(review): in the visible lines the `pnames` vector itself is never
// deleted. Confirm that an elided line releases it; otherwise it leaks.
139 for (auto pname : *pnames) delete pname;
// updateProvisionalName: called once the real MCID variable (mcVar) for
// `now_known` is known. For every deferred update, it finds each enabled
// placeholder that refers to that declaration and replaces the placeholder's
// `length` characters at `index` with mcVar. It then shifts the offsets of
// later placeholders in the same update to account for the length change.
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145 for (Update * u : DeferredUpdates) {
146 for (int i = 0; i < u->pnames->size(); i++) {
147 ProvisionalName * v = (*(u->pnames))[i];
148 if (!v->enabled) continue;
149 if (now_known == v->pname->getDecl()) {
// NOTE(review): oldName is not used on any visible line.
151 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
153 u->update.replace(v->index, v->length, mcVar);
// Shift only placeholders stored after position i that start past v.
// NOTE(review): this assumes pnames is ordered by index.
// NOTE(review): `v->length - mcVar.length()` is evaluated in size_t, so it
// wraps when mcVar is longer than the placeholder. Converting back to int
// gives the intended negative delta on two's-complement targets, but an
// explicit signed delta would be clearer.
154 for (int j = i+1; j < u->pnames->size(); j++) {
155 ProvisionalName * vv = (*(u->pnames))[j];
156 if (vv->index > v->index)
157 vv->index -= v->length - mcVar.length();
// retrieveSingleDecl: returns the VarDecl declared by a DeclStmt that
// declares exactly one variable. Any other shape trips the assert in debug
// builds. With NDEBUG it falls through to the elided remainder of the
// function (presumably returning null; confirm).
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165 // XXX iterate through all decls defined in s, not just the first one
166 assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167 if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168 return cast<VarDecl>(s->getSingleDecl());
// FindCallArgVisitor: RecursiveASTVisitor used to pick apart an argument of
// an MC2 load/store/RMW call, or a branch condition.
//   - DE: the first DeclRefExpr visited.
//   - UnaryOp: presumably the address-of / dereference operator seen (the
//     assignment is on an elided line).
// Instances are reused across traversals. A statement that nulls both
// fields (its method header is elided) resets the visitor.
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
174 FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
176 bool VisitStmt(Stmt * s) {
// Only `&x` and `*x` unary operators are of interest here.
178 if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180 uo->getOpcode() == UnaryOperatorKind::UO_Deref)
// Only the first DeclRefExpr in traversal order is recorded.
185 if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
191 UnaryOp = NULL; DE = NULL;
// RetrieveUnaryOp: starts from the recorded UnaryOp and walks down through
// nested unary operators, casts and member accesses (rest of logic elided).
194 const UnaryOperator * RetrieveUnaryOp() {
197 const Stmt * s = UnaryOp;
203 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204 s = op->getSubExpr();
205 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206 s = op->getSubExpr();
207 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
// RetrieveDeclRefExpr: presumably returns DE (body elided). DE starts as
// NULL, so callers that dereference the result assume a variable was found.
218 const DeclRefExpr * RetrieveDeclRefExpr() {
223 const UnaryOperator * UnaryOp;
224 const DeclRefExpr * DE;
// FindLocalsVisitor: collects the declaration referenced by every
// DeclRefExpr in the traversed subtree. Duplicates are kept, and despite the
// name the visible lines do not restrict the result to local variables.
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
229 FindLocalsVisitor() : Vars() {}
231 bool VisitDeclRefExpr(DeclRefExpr * de) {
232 Vars.push_back(de->getDecl());
// Returns the collected declarations by value (a copy of Vars).
240 const TinyPtrVector<const NamedDecl *> RetrieveVars() {
245 TinyPtrVector<const NamedDecl *> Vars;
// MallocHandler: match callback that records every CallExpr bound as
// "callExpr" into the shared MallocExprs set. The assignment handler checks
// that set so that assignments from these calls are annotated even when
// they have no MC-visible inputs. The matcher that decides which calls are
// recorded is not visible here (presumably malloc-like allocations).
249 class MallocHandler : public MatchFinder::MatchCallback {
251 MallocHandler(std::set<const Expr *> & MallocExprs) :
252 MallocExprs(MallocExprs) {}
254 virtual void run(const MatchFinder::MatchResult &Result) {
255 const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
257 MallocExprs.insert(ce);
261 std::set<const Expr *> &MallocExprs;
// generateMC2Function: queues (appends to DeferredUpdates) a snippet that:
//   1. stores the pretty-printed expression `e` into a void* temporary, and
//   2. emits an MC2_function_id call over that temporary that depends on
//      lhs's MCID.
// The lhs MCID argument is written as a ProvisionalName placeholder
// (encode(lhsName)), so updateProvisionalName() can patch it once the real
// MCID for lhs is known.
// NOTE(review): e, tmpname, tmpFn, lhsName, loc and SStr come from lines
// elided from this listing (parameters or locals).
264 static void generateMC2Function(Rewriter & Rewrite,
269 const DeclRefExpr * lhs,
271 std::vector<ProvisionalName *> * vars1,
272 std::vector<Update *> & DeferredUpdates) {
273 // prettyprint the LHS (&newnode->value)
274 // e.g. void * _tmp0 = &newnode->value;
276 llvm::raw_string_ostream S(SStr);
277 e->printPretty(S, nullptr, Rewrite.getLangOpts());
278 const std::string &Str = S.str();
280 std::stringstream prel;
281 prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
283 // MCID _p0 = MC2_function_id(<++funcCount>, 1, MC2_PTR_LENGTH, _tmp0, _fn0);
284 prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
286 // XXX generate casts when they'd eliminate warnings
// The placeholder starts exactly where encode(lhsName) is written below.
287 ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
290 prel << encode(lhsName) << "); ";
292 Update * u = new Update(loc, prel.str(), vars1);
293 DeferredUpdates.push_back(u);
// LoadHandler: match callback for an MC2 load call ("callExpr"), optionally
// bound to a declaration ("decl") and/or its containing statement
// ("containingStmt"). It:
//   - chooses the MCID name for the loaded value;
//   - records that name in DeclToMCVar and patches pending placeholders;
//   - queues, at the FRONT of DeferredUpdates, a generated
//     "<n_decl><mcvar>=MC2_nextOpLoad(...)" or
//     "...=MC2_nextOpLoadOffset(..., MC2_OFFSET(...))" snippet.
296 class LoadHandler : public MatchFinder::MatchCallback {
298 LoadHandler(Rewriter &Rewrite,
299 std::set<const NamedDecl *> & DeclsRead,
300 std::set<const VarDecl *> & DeclsNeedingMC,
301 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302 std::map<const Expr *, std::string> &ExprToMCVar,
303 std::set<const Stmt *> & StmtsHandled,
304 std::map<const Expr *, SourceLocation> &Redirector,
305 std::vector<Update *> & DeferredUpdates) :
306 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307 ExprToMCVar(ExprToMCVar),
308 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
310 virtual void run(const MatchFinder::MatchResult &Result) {
311 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313 const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
// lhs is only known when the load's containing statement is an assignment.
315 const Expr * lhs = NULL;
316 if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
319 const DeclRefExpr * rhs = NULL;
320 MemberExpr * ml = NULL;
321 bool isAddrOfR = false, isAddrMemberR = false;
// Mark the statement so other handlers (e.g. the assignment handler) skip it.
// NOTE(review): s may be null here (see the `s &&` guard above). It is
// inserted anyway, and s->getLocStart() near the end dereferences it.
323 StmtsHandled.insert(s);
325 std::string n, n_decl;
327 FindCallArgVisitor fcaVisitor;
// Arg 0 is the address being loaded from. RetrieveDeclRefExpr() is
// dereferenced without a null check.
329 fcaVisitor.TraverseStmt(ce->getArg(0));
330 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
// NOTE(review): ruop is dereferenced here; any guarding condition is on an
// elided line.
335 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
337 FindLocalsVisitor flv;
// Every variable mentioned in the assignment's LHS is mapped to the load's
// MCID name.
// NOTE(review): the loop variable `d` shadows the outer `d` (the bound VarDecl).
339 flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340 for (auto & d : flv.RetrieveVars()) {
341 const VarDecl * dd = cast<VarDecl>(d);
343 // XXX todo rhs for non-decl stmts
344 if (!isa<ParmVarDecl>(dd))
345 DeclsNeedingMC.insert(dd);
347 DeclToMCVar[dd] = encode(n);
350 FindCallArgVisitor fcaVisitor;
352 fcaVisitor.TraverseStmt(ce->getArg(0));
353 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
358 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
362 DeclsNeedingMC.insert(d);
364 DeclToMCVar[d] = encode(n);
// Patch pending placeholders for the first variable found in the call,
// then remember its MCID name.
368 fcaVisitor.TraverseStmt(ce);
369 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371 DeclToMCVar[dd->getDecl()] = encode(n);
// Build the generated load statement in `nol`. Each ProvisionalName records
// the offset at which a "_m<name>" placeholder is about to be written.
376 std::stringstream nol;
378 if (lhs && isa<DeclRefExpr>(lhs)) {
379 const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380 ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
387 nol << n_decl << encode(n) << "=";
388 nol << "MC2_nextOpLoadOffset(";
390 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
392 nol << encode(rhs->getNameInfo().getName().getAsString());
394 nol << ", MC2_OFFSET(";
395 nol << ml->getBase()->getType().getAsString();
397 nol << ml->getMemberDecl()->getName().str();
399 } else if (!isAddrOfR) {
401 nol << n_decl << encode(n) << "=";
402 nol << "MC2_nextOpLoad(";
403 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
405 nol << encode(rhs->getNameInfo().getName().getAsString());
408 nol << n_decl << encode(n) << "=";
409 nol << "MC2_nextOpLoad(";
414 nol << n_decl << encode(n) << "=";
415 nol << "MC2_nextOpLoad(";
// Insertion point: the start of the containing statement, adjusted for
// compound statements. It may be redirected, because Redirector maps a
// condition expression that was hoisted out of an if to the if's start.
423 SourceLocation ss = s->getLocStart();
425 // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426 // move over 1 so that we get the right location.
427 if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428 const Expr * e = dyn_cast<Expr>(s);
429 if (e && Redirector.count(e) > 0)
// Loads are queued at the front of DeferredUpdates; stores are appended.
431 Update * u = new Update(ss, nol.str(), vars);
432 DeferredUpdates.insert(DeferredUpdates.begin(), u);
// NOTE(review): members are declared in a different order than the
// constructor's initializer list (-Wreorder). This is harmless because each
// is an independent reference.
437 std::set<const NamedDecl *> & DeclsRead;
438 std::set<const VarDecl *> & DeclsNeedingMC;
439 std::map<const Expr *, std::string> &ExprToMCVar;
440 std::map<const NamedDecl *, std::string> &DeclToMCVar;
441 std::set<const Stmt *> &StmtsHandled;
442 std::vector<Update *> &DeferredUpdates;
443 std::map<const Expr *, SourceLocation> &Redirector;
// StoreHandler: match callback for an MC2 store call ("callExpr"), where
// arg 0 is the stored-to location and arg 1 the stored value. It appends to
// DeferredUpdates a "MC2_nextOpStore(...)" or "MC2_nextOpStoreOffset(...)"
// call whose operands are written as ProvisionalName placeholders:
//   - the MCID of the address variable, plus, for `&p->field`, an
//     MC2_OFFSET(...) built from the member's base type and name;
//   - the MCID of the stored value.
// Side effects:
//   - Stored-value variables are added to DeclsRead.
//   - In the offset form, non-parameter address variables are added to
//     DeclsNeedingMC.
446 class StoreHandler : public MatchFinder::MatchCallback {
448 StoreHandler(Rewriter &Rewrite,
449 std::set<const NamedDecl *> & DeclsRead,
450 std::set<const VarDecl *> &DeclsNeedingMC,
451 std::vector<Update *> & DeferredUpdates) :
452 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
454 virtual void run(const MatchFinder::MatchResult &Result) {
455 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
// fcaVisitor is a member reused across calls and traversals; any reset
// between traversals is on elided lines.
459 fcaVisitor.TraverseStmt(ce->getArg(0));
460 const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461 const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
463 std::stringstream nol;
// Classify the address argument: `&p->field` (member) vs `&x` (plain).
468 if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469 isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470 isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
475 nol << "MC2_nextOpStore(";
479 MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
481 nol << "MC2_nextOpStoreOffset(";
483 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
486 nol << encode(lhs->getNameInfo().getName().getAsString());
487 if (!isa<ParmVarDecl>(lhs->getDecl()))
488 DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
490 nol << ", MC2_OFFSET(";
491 nol << ml->getBase()->getType().getAsString();
493 nol << ml->getMemberDecl()->getName().str();
496 nol << "MC2_nextOpStore(";
497 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
500 nol << encode(lhs->getNameInfo().getName().getAsString());
505 nol << "MC2_nextOpStore(";
// Arg 1: the value being stored. A plain variable contributes its MCID
// placeholder; dereferences must go through an explicit atomic load.
512 fcaVisitor.TraverseStmt(ce->getArg(1));
513 const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
516 bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517 bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
519 if (rhs && !isAddrOfR) {
520 assert (!isDerefR && "Must use atomic load for dereferences!");
521 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
524 nol << encode(rhs->getNameInfo().getName().getAsString());
525 DeclsRead.insert(rhs->getDecl());
// Stores are appended at the back of DeferredUpdates (loads/RMWs go in front).
531 Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532 DeferredUpdates.push_back(u);
537 FindCallArgVisitor fcaVisitor;
538 std::set<const NamedDecl *> & DeclsRead;
539 std::set<const VarDecl *> & DeclsNeedingMC;
540 std::vector<Update *> &DeferredUpdates;
// RMWHandler: match callback for an MC2 read-modify-write call
// ("callExpr") whose result is assigned in `containingStmt`, which is
// either a declaration or an assignment.
//   - The assigned variable gets a fresh RMW MCID name (encodeRMW), which
//     is recorded in DeclToMCVar.
//   - Arg 1 is the RMW's target location; args 2-3 are its operands.
// The generated "MCID <rmw> = MC2_nextRMW(...)" or
// "... = MC2_nextRMWOffset(...)" snippet is queued at the front of
// DeferredUpdates.
543 class RMWHandler : public MatchFinder::MatchCallback {
545 RMWHandler(Rewriter &rewrite,
546 std::set<const NamedDecl *> & DeclsRead,
547 std::set<const NamedDecl *> & DeclsInCond,
548 std::map<const NamedDecl *, std::string> &DeclToMCVar,
549 std::map<const Expr *, std::string> &ExprToMCVar,
550 std::set<const Stmt *> & StmtsHandled,
551 std::map<const Expr *, SourceLocation> &Redirector,
552 std::vector<Update *> & DeferredUpdates) :
553 rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
554 ExprToMCVar(ExprToMCVar),
555 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
557 virtual void run(const MatchFinder::MatchResult &Result) {
558 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
562 std::stringstream nol;
564 std::string rmwMCVar;
565 rmwMCVar = encodeRMW(rmwCount++);
567 const VarDecl * rmw_lhs;
// NOTE(review): because && binds tighter than ||, the message string is
// attached only to the BinaryOperator test. The truth value is unaffected.
// isa<> on a null `s` would assert, yet the insertion-point code below
// treats a null `s` as possible.
569 StmtsHandled.insert(s);
570 assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
572 if ((ds = dyn_cast<DeclStmt>(s))) {
573 rmw_lhs = retrieveSingleDecl(ds);
575 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576 assert (isa<DeclRefExpr>(e));
577 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
579 DeclToMCVar[rmw_lhs] = rmwMCVar;
582 // retrieve effective LHS of the RMW
584 fcaVisitor.TraverseStmt(ce->getArg(1));
585 const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586 const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587 bool isAddrMemberL = false;
589 if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590 isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
593 nol << "MCID " << rmwMCVar;
595 MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
597 nol << " = MC2_nextRMWOffset(";
599 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
602 nol << encode(elhs->getNameInfo().getName().getAsString());
604 nol << ", MC2_OFFSET(";
605 nol << ml->getBase()->getType().getAsString();
607 nol << ml->getMemberDecl()->getName().str();
610 nol << " = MC2_nextRMW(";
611 bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
617 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
620 std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
629 // handle both RHS ops
// For each operand variable:
//   - it is added to DeclsInCond;
//   - it is added to DeclsRead;
//   - its known MCID is emitted, or else a placeholder sized to
//     "MCID_NODEP" is recorded.
631 for (int arg = 2; arg < 4; arg++) {
633 fcaVisitor.TraverseStmt(ce->getArg(arg));
634 const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635 const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
637 bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638 bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
640 if (a && !isAddrOfR) {
641 assert (!isDerefR && "Must use atomic load for dereferences!");
643 DeclsInCond.insert(a->getDecl());
645 if (outputted > 0) nol << ", ";
648 bool alreadyMCVar = false;
649 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
651 nol << DeclToMCVar[a->getDecl()];
654 std::string an = "MCID_NODEP";
655 ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
660 DeclsRead.insert(a->getDecl());
663 if (outputted > 0) nol << ", ";
// Insert before the containing statement (or the call itself), honouring
// Redirector for expressions that were hoisted elsewhere.
671 SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672 const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673 if (e && Redirector.count(e) > 0)
674 place = Redirector[e];
675 Update * u = new Update(place, nol.str(), vars);
676 DeferredUpdates.insert(DeferredUpdates.begin(), u);
681 FindCallArgVisitor fcaVisitor;
682 std::set<const NamedDecl *> &DeclsRead;
683 std::set<const NamedDecl *> &DeclsInCond;
684 std::map<const NamedDecl *, std::string> &DeclToMCVar;
685 std::map<const Expr *, std::string> &ExprToMCVar;
686 std::set<const Stmt *> &StmtsHandled;
687 std::vector<Update *> &DeferredUpdates;
688 std::map<const Expr *, SourceLocation> &Redirector;
// FindReturnsBreaksVisitor: collects every ReturnStmt and BreakStmt in the
// traversed subtree.
// NOTE(review): the visible code does not skip statements inside nested
// loops or switches, so breaks that belong to inner constructs are
// collected too.
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
693 FindReturnsBreaksVisitor() : Returns(), Breaks() {}
695 bool VisitStmt(Stmt * s) {
696 if (isa<ReturnStmt>(s))
697 Returns.push_back(cast<ReturnStmt>(s));
699 if (isa<BreakStmt>(s))
700 Breaks.push_back(cast<BreakStmt>(s));
// Reset both lists (method header elided).
705 Returns.clear(); Breaks.clear();
// Both accessors return copies of the collected lists.
708 const std::vector<const ReturnStmt *> RetrieveReturns() {
712 const std::vector<const BreakStmt *> RetrieveBreaks() {
717 std::vector<const ReturnStmt *> Returns;
718 std::vector<const BreakStmt *> Breaks;
// LoopHandler: for each matched loop statement (bound as "s"), inserts:
//   - "MC2_enterLoop();" before the loop;
//   - "MC2_exitLoop();" before every return found in the loop body;
//   - "MC2_exitLoop();" after the loop.
721 class LoopHandler : public MatchFinder::MatchCallback {
723 LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
725 virtual void run(const MatchFinder::MatchResult &Result) {
726 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
728 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729 "MC2_enterLoop();\n", true, true);
731 // annotate all returns with MC2_exitLoop()
732 // annotate all breaks that aren't further nested with MC2_exitLoop().
// The body is traversed according to the loop kind (for/while/do); the
// guards selecting each cast are partly on elided lines.
733 FindReturnsBreaksVisitor frbv;
735 frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736 if (isa<WhileStmt>(s))
737 frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
739 frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
// NOTE(review): unlike the other insertions, return locations are not mapped
// through getExpansionLoc(), so returns written inside macros may be
// mishandled. Collected breaks are not annotated on any visible line.
741 for (auto & r : frbv.RetrieveReturns()) {
742 rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
745 // need to find all breaks and returns embedded inside the loop
747 rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748 "\nMC2_exitLoop();\n");
755 /* Inserts MC2_function_id for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
// AssignHandler: match callback for an assignment operator ("op") or an
// initializing declaration ("containingStmt"). Statements already claimed
// by other handlers (StmtsHandled) appear to be skipped.
// An annotation is generated when the right-hand side either:
//   - mentions variables that already have an MCID (DeclToMCVar) or are
//     MC-read (DeclsRead), or
//   - is one of the recorded MallocExprs.
// In that case the handler:
//   1. inserts "\n<fnvar> = MC2_function_id(<id>, <#inputs>,
//      sizeof (lhs), (uint64_t)lhs, ...)" after the statement;
//   2. records <fnvar> as the lhs's MCID;
//   3. patches pending placeholders for lhs.
756 class AssignHandler : public MatchFinder::MatchCallback {
758 AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759 std::set<const NamedDecl *> &DeclsInCond,
760 std::set<const VarDecl *> &DeclsNeedingMC,
761 std::map<const NamedDecl *, std::string> &DeclToMCVar,
762 std::set<const Stmt *> &StmtsHandled,
763 std::set<const Expr *> &MallocExprs,
764 std::vector<Update *> &DeferredUpdates) :
766 DeclsRead(DeclsRead),
767 DeclsInCond(DeclsInCond),
768 DeclsNeedingMC(DeclsNeedingMC),
769 DeclToMCVar(DeclToMCVar),
770 StmtsHandled(StmtsHandled),
771 MallocExprs(MallocExprs),
772 DeferredUpdates(DeferredUpdates) {}
774 virtual void run(const MatchFinder::MatchResult &Result) {
775 BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777 FindLocalsVisitor locals, locals_rhs;
779 const VarDecl * lhs = NULL;
780 const Expr * rhs = NULL;
783 if (s && (ds = dyn_cast<DeclStmt>(s))) {
784 // long term goal: refactor the run() method to deal with one assignment at a time
785 // for now, if there is only declarations and no rhs's, we'll ignore this stmt
786 if (!ds->isSingleDecl()) {
787 for (auto & d : ds->decls()) {
788 VarDecl * vd = dyn_cast<VarDecl>(d);
// NOTE(review): this tests `d` rather than `vd`. If a declaration in the
// group is not a VarDecl, vd is null and vd->hasInit() dereferences null;
// `!vd` looks intended.
789 if (!d || vd->hasInit())
790 assert(0 && "unsupported form of decl");
795 lhs = retrieveSingleDecl(ds);
798 if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
802 if (lhs->hasInit()) {
803 rhs = lhs->getInit();
805 rhs = rhs->IgnoreCasts();
// mcState: MCID names of the MC-visible values the right-hand side depends on.
811 std::set<std::string> mcState;
814 bool rhsRead = false;
816 bool lhsTooComplicated = false;
819 if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
820 lhs = dyn_cast<VarDecl>(vd->getDecl());
822 // kick the can along...
823 lhsTooComplicated = true;
828 rhs = rhs->IgnoreCasts();
831 // rhs must be MC-active state, i.e. in declsread
832 // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
// NOTE(review): in the visible lines both `locals` and `locals_rhs` traverse
// rhs, so the two mcState loops below add the same names (the std::set
// drops the duplicates).
835 locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
836 for (auto & nd : locals_rhs.RetrieveVars()) {
837 if (DeclsRead.find(nd) != DeclsRead.end())
842 locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
844 lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
846 for (auto & d : locals.RetrieveVars()) {
847 if (DeclToMCVar.count(d) > 0)
848 mcState.insert(DeclToMCVar[d]);
849 else if (DeclsRead.find(d) != DeclsRead.end())
850 mcState.insert(encode(d->getName().str()));
854 for (auto & d : locals_rhs.RetrieveVars()) {
855 if (DeclToMCVar.count(d) > 0)
856 mcState.insert(DeclToMCVar[d]);
857 else if (DeclsRead.find(d) != DeclsRead.end())
858 mcState.insert(encode(d->getName().str()));
861 if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
// NOTE(review): with NDEBUG this assert disappears, and a null lhs then
// reaches lhs->getName() below.
862 if (lhsTooComplicated)
863 assert(0 && "couldn't find LHS of = operator");
865 std::stringstream nol;
866 std::string _lhsStr, lhsStr;
867 std::string mcVar = encodeFn(fnCount++);
869 lhsStr = lhs->getName().str();
870 _lhsStr = encode(lhsStr);
871 DeclToMCVar[lhs] = mcVar;
872 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
// A fresh ++funcCount id is taken only when rhs is NOT a recorded malloc
// expression. Otherwise function_id keeps the value it was given on an
// elided line.
875 if (!(MallocExprs.find(rhs) != MallocExprs.end()))
876 function_id = ++funcCount;
877 nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
879 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
881 nol << ", MC2_PTR_LENGTH";
882 for (auto & d : mcState) {
// Insertion point: just past the assignment's last token, or the end of the
// containing statement (the selecting condition is on an elided line).
890 SourceLocation place;
892 place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
894 place = s->getLocEnd();
895 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
896 nol.str(), true, true);
898 updateProvisionalName(DeferredUpdates, lhs, mcVar);
904 std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
905 std::set<const VarDecl *> &DeclsNeedingMC;
906 std::map<const NamedDecl *, std::string> &DeclToMCVar;
907 std::set<const Stmt *> &StmtsHandled;
908 std::set<const Expr *> &MallocExprs;
909 std::vector<Update *> &DeferredUpdates;
912 // record vars used in conditions
// BranchConditionRefactoringHandler: for a matched if-statement ("if"):
//   1. Hoists the condition into a fresh "int _condN" temporary emitted
//      before the if.
//   2. Pairs it with an MCID "_condN_m". For x == y between two plain
//      variables this uses MC2_equals; otherwise it uses MC2_function_id.
//   3. Replaces the condition text with the temporary.
//   4. Records the MCID in ExprToMCVar.
//   5. Starting from the condition's variable and following VarDecl
//      initializers, adds every declaration reached to DeclsInCond.
913 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
915 BranchConditionRefactoringHandler(Rewriter &rewrite,
916 std::set<const NamedDecl *> & DeclsInCond,
917 std::map<const NamedDecl *, std::string> &DeclToMCVar,
918 std::map<const Expr *, std::string> &ExprToMCVar,
919 std::map<const Expr *, SourceLocation> &Redirector,
920 std::vector<Update *> &DeferredUpdates) :
921 rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
922 ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
924 virtual void run(const MatchFinder::MatchResult &Result) {
925 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
926 Expr * cond = is->getCond();
928 // refactor out complicated conditions
929 FindCallArgVisitor flv;
930 flv.TraverseStmt(cond);
// Binary-operator condition ("bc" bound): emitted directly with InsertText.
933 BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
935 std::string condVar = encodeCond(condCount++);
936 std::stringstream condVarEncoded;
937 condVarEncoded << condVar << "_m";
939 // prettyprint the binary op
940 // e.g. int _cond0 = x == y;
942 llvm::raw_string_ostream S(SStr);
943 bc->printPretty(S, nullptr, rewrite.getLangOpts());
944 const std::string &Str = S.str();
946 std::stringstream prel;
948 bool is_equality = false;
949 // handle equality tests
950 if (bc->getOpcode() == BO_EQ) {
951 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
952 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
953 DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
955 prel << "\nMCID " << condVarEncoded.str() << ";\n";
957 if (DeclToMCVar.find(l->getDecl()) != DeclToMCVar.end())
958 ld = DeclToMCVar.find(l->getDecl())->second;
960 ld = encode(l->getDecl()->getName());
961 if (DeclToMCVar.find(r->getDecl()) != DeclToMCVar.end())
962 rd = DeclToMCVar.find(r->getDecl())->second;
964 rd = encode(r->getDecl()->getName());
966 prel << "\nint " << condVar << " = MC2_equals(" <<
967 ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
968 rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
969 "&" << condVarEncoded.str() << ");\n";
// NOTE(review): RetrieveDeclRefExpr() returns NULL when the condition
// contains no variable reference, but d is dereferenced without a check
// here and in the later uses below.
974 prel << "\nint " << condVar << " = " << Str << ";";
975 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
976 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
977 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
978 prel << ", " << DeclToMCVar[d->getDecl()];
983 ExprToMCVar[cond] = condVarEncoded.str();
984 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
985 prel.str(), false, true);
987 // rewrite the binary op with the newly-inserted var
988 Expr * RO = bc->getRHS(); // used for location only
990 int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
991 SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
992 rewrite.ReplaceText(SR, condVar);
// Other conditions (e.g. a call): the snippet is queued as a deferred
// Update so that its MCID argument can be patched later.
994 std::string condVar = encodeCond(condCount++);
995 std::stringstream condVarEncoded;
996 condVarEncoded << condVar << "_m";
999 llvm::raw_string_ostream S(SStr);
1000 cond->printPretty(S, nullptr, rewrite.getLangOpts());
1001 const std::string &Str = S.str();
1003 std::stringstream prel;
1004 prel << "\nint " << condVar << " = " << Str << ";";
1005 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
1006 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
1007 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
1008 if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
1009 prel << ", " << DeclToMCVar[d->getDecl()];
// A zero-length placeholder is a pure insertion point. updateProvisionalName()
// inserts d's MCID here once that MCID is known.
1012 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
1017 ExprToMCVar[cond] = condVarEncoded.str();
1018 // gross hack; should look for any callexprs in cond
1019 // but right now, if it's a unaryop, just manually traverse
1020 if (isa<UnaryOperator>(cond)) {
1021 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
1022 ExprToMCVar[e] = condVarEncoded.str();
1024 Update * u = new Update(is->getLocStart(), prel.str(), vars);
1025 DeferredUpdates.push_back(u);
1027 // rewrite the call op with the newly-inserted var
// Redirector lets load/RMW handlers matched inside the original condition
// place their code before the if instead.
1028 SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1029 Redirector[cond] = is->getLocStart();
1030 rewrite.ReplaceText(SR, condVar);
// Walk from the condition's variable through VarDecl initializers, marking
// every declaration reached as used in a condition.
// NOTE(review): the inner `d` declarations shadow the outer one.
1033 std::deque<const Decl *> q;
1034 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1036 while (!q.empty()) {
1037 const Decl * d = q.back();
1039 if (isa<NamedDecl>(d))
1040 DeclsInCond.insert(cast<NamedDecl>(d));
1043 if ((vd = dyn_cast<VarDecl>(d))) {
1044 if (vd->hasInit()) {
1045 const Expr * e = vd->getInit();
1047 flv.TraverseStmt(const_cast<Expr *>(e));
1048 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1057 std::set<const NamedDecl *> & DeclsInCond;
1058 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1059 std::map<const Expr *, std::string> &ExprToMCVar;
1060 std::map<const Expr *, SourceLocation> &Redirector;
1061 std::vector<Update *> &DeferredUpdates;
1064 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1066 BranchAnnotationHandler(Rewriter &rewrite,
1067 std::map<const NamedDecl *, std::string> & DeclToMCVar,
1068 std::map<const Expr *, std::string> & ExprToMCVar)
1070 DeclToMCVar(DeclToMCVar),
1071 ExprToMCVar(ExprToMCVar){}
// Annotates one matched `if` statement (bound as "if") with MC2 branch
// bookkeeping:
//   * declares `MCID <br>;` just before the `if`, where
//     <br> = encodeBranch(branchCount++);
//   * inserts `<br> = MC2_branchUsesID(<cond>, 1, 2, true);` at the start of
//     the then-arm and the same call with 0 at the start of the else-arm;
//   * inserts `MC2_merge(<br>);` at the end of each arm whose last statement
//     is not a `return` or `break`;
//   * wraps single-statement arms in braces and, when the `if` has no
//     else-arm, appends a synthesized ` else { ... }` so that the not-taken
//     path is annotated too.
// If FindLocalsVisitor finds no local variable in the condition, the `if`
// is left untouched.
1072 virtual void run(const MatchFinder::MatchResult &Result) {
1073 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1075 // if the branch condition is interesting:
1076 // (but right now, not too interesting)
1077 Expr * cond = is->getCond()->IgnoreCasts();
// Only the first local variable found in the condition is used to choose
// the ID that the branch depends on.
1079 FindLocalsVisitor flv;
1080 flv.TraverseStmt(cond);
1081 if (flv.RetrieveVars().size() == 0) return;
1083 const NamedDecl * condVar = flv.RetrieveVars()[0];
// Condition ID, in order of preference: an ID recorded for the condition
// expression itself, then one recorded for the variable, and otherwise the
// encoded variable name.
1085 std::string mCondVar;
1086 if (ExprToMCVar.count(cond) > 0)
1087 mCondVar = ExprToMCVar[cond];
1088 else if (DeclToMCVar.count(condVar) > 0)
1089 mCondVar = DeclToMCVar[condVar];
1091 mCondVar = encode(condVar->getName());
1092 std::string brVar = encodeBranch(branchCount++);
// Declare the branch ID immediately before the `if`.
1094 std::stringstream brline;
1095 brline << "MCID " << brVar << ";\n";
1096 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1097 brline.str(), false, true);
// tfl: insertion point for the then-arm annotation. It is the first
// statement of a non-empty compound arm, the statement itself for a
// single-statement arm, or one character past the `{` of an empty `{}` arm.
// tsl: one character before the then-arm's end location.
1099 Stmt * ts = is->getThen(), * es = is->getElse();
1100 bool tHasChild = hasChild(ts);
1103 if (isa<CompoundStmt>(ts))
1104 tfl = getFirstChild(ts)->getLocStart();
1106 tfl = ts->getLocStart();
1108 tfl = ts->getLocStart().getLocWithOffset(1);
1109 SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
// Text snippets: the taken-arm annotation (with a trailing newline), the
// not-taken-arm annotation (no trailing newline), and the merge call used
// by both arms.
1111 std::stringstream tlineStart, mergeStmt, eline;
// NOTE(review): `uop` appears to be unused.
1113 UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1114 tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1115 eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1117 mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1119 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
// NOTE(review): extra_else_offset is assigned below but does not appear to
// be read; confirm whether it is still needed.
1122 int extra_else_offset = 0;
1124 if (tHasChild) { tls = getLastChild(ts); }
1125 if (tls) extra_else_offset = 2; else extra_else_offset = 1;
// An arm ending in `return`/`break` never falls through, so MC2_merge is
// added only to arms that can fall through. It is placed at the then-arm's
// end location.
// NOTE(review): for a single-statement then-arm, ts->getLocEnd() is the
// start of the statement's last token. The merge text therefore lands
// before that token rather than after the `;`. Verify this case.
1127 if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1128 extra_else_offset = 0;
1129 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1130 mergeStmt.str(), true, true);
// Brace a single-statement then-arm: `{` goes before the statement and
// `}` goes just after its terminating `;`.
1132 if (tHasChild && !isa<CompoundStmt>(ts)) {
1133 rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1134 SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1135 rewrite.InsertText(tend, "}", true, true);
1137 if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
// An else-arm exists: annotate it the same way, using the "0" (not taken)
// call.
1140 SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1141 bool eHasChild = hasChild(es);
1143 if (eHasChild) els = getLastChild(es); else els = es;
1149 if (isa<CompoundStmt>(es))
1150 el = getFirstChild(es)->getLocStart();
1152 el = es->getLocStart();
1155 el = es->getLocStart().getLocWithOffset(1);
1156 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
// NOTE(review): es->getLocEnd() is the start of the else-arm's last token.
// For a single-statement arm, the `}` below therefore lands one character
// after that token's start, i.e. before the `;`. The then-arm uses
// Lexer::findLocationAfterToken for this instead.
1158 if (eHasChild && !isa<CompoundStmt>(es)) {
1159 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1160 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1163 if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1164 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
// No else-arm: synthesize ` else { <not-taken annotation> MC2_merge(...) }`
// after the then-arm. The anchor is just after the arm's trailing `;` if
// there is one, otherwise just after its last token; the text is inserted
// one character further on.
1167 std::stringstream eCompoundLine;
1168 eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1169 SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1170 if (!tend.isValid())
1171 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1172 rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
// Returns false only for an empty compound statement `{}`. Any other
// statement counts as having a child: a single statement, or a compound
// statement with at least one child.
1177 bool hasChild(Stmt * s) {
1178 if (!isa<CompoundStmt>(s)) return true;
1179 return (!cast<CompoundStmt>(s)->body_empty());
// Returns the first statement of `s`. Asserts both that `s` is a
// CompoundStmt and that it is non-empty.
1182 Stmt * getFirstChild(Stmt * s) {
1183 assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1184 assert(!cast<CompoundStmt>(s)->body_empty());
1185 return *(cast<CompoundStmt>(s)->body_begin());
// Returns the last statement of `s` when `s` is a CompoundStmt, asserting
// that it is non-empty. For any other statement it presumably returns `s`
// itself, since callers treat the result as the arm's last statement.
// TODO confirm.
1188 Stmt * getLastChild(Stmt * s) {
1190 if ((cs = dyn_cast<CompoundStmt>(s))) {
1191 assert (!cs->body_empty());
1192 return cs->body_back();
// Name tables shared with the other handlers. The maps are owned by the AST
// consumer and passed in by reference. They map a declaration to its MCID
// variable name and an expression to its MCID variable name; run() only
// reads them.
1198 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1199 std::map<const Expr *, std::string> &ExprToMCVar;
// Threads MC2 dependency IDs through function calls. Bound nodes are
// "callExpr" (the call) and, when present, "containingStmt" (the
// declaration statement or `=` expression that contains the call).
//  * thrd_create(...): the function passed as argument 1 is recorded in
//    ThreadMains.
//  * A non-void call inside a containing statement gets a fresh
//    return-value ID: `MCID <rv>;` is declared before that statement and
//    `&<rv>` is appended as an extra last argument. When the containing
//    statement is a declaration, the declared variable (from
//    retrieveSingleDecl) is mapped to <rv> in DeclToMCVar.
//  * For every argument, the MCID of a plain variable argument is looked
//    up in DeclToMCVar (MCID_NODEP otherwise), and text is inserted in front
//    of the argument (presumably that ID followed by `, `).
1202 class FunctionCallHandler : public MatchFinder::MatchCallback {
1204 FunctionCallHandler(Rewriter &rewrite,
1205 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1206 std::set<const FunctionDecl *> &ThreadMains)
1207 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1209 virtual void run(const MatchFinder::MatchResult &Result) {
1210 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
// NOTE(review): getCalleeDecl() may return null, e.g. when the callee is
// not a direct reference to a declaration, and dyn_cast<> asserts on a null
// argument. Using dyn_cast_or_null<> plus an early return for a null `nd`
// would make such calls safe.
1211 Decl * d = ce->getCalleeDecl();
1212 NamedDecl * nd = dyn_cast<NamedDecl>(d);
// `s` is null when no "containingStmt" was bound, e.g. for a call that is
// a statement on its own.
1213 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1214 ASTContext *Context = Result.Context;
// thrd_create(thr, func, arg): record `func`, with an optional leading `&`
// stripped, as a thread entry function.
1216 if (nd->getName() == "thrd_create") {
1217 Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1218 UnaryOperator * callee1;
1219 if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1220 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1221 callee0 = callee1->getSubExpr();
1223 DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1224 if (!callee) return;
1225 FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
// NOTE(review): `fd` is null when argument 1 names something other than a
// function (e.g. a function-pointer variable). Null is then inserted into
// ThreadMains.
1226 ThreadMains.insert(fd);
// Non-void call inside an enclosing statement: introduce a return-value ID.
1233 if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1234 // TODO check that the type is mc-visible also?
1235 const DeclStmt * ds;
1236 const VarDecl * lhs = NULL;
1237 std::string mc_rv = encodeRV(rvCount++);
// Declare `MCID <rv>;` before the enclosing statement.
1239 std::stringstream brline;
1240 brline << "MCID " << mc_rv << ";\n";
1241 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1242 brline.str(), false, true);
// Pass `&<rv>` as an extra last argument, preceded by `, ` if the call
// already has arguments.
1244 std::stringstream nol;
1245 if (ce->getNumArgs() > 0) nol << ", ";
1246 nol << "&" << mc_rv;
1247 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
// If the call sits in a declaration, map the declared variable to <rv>.
1250 if (s && (ds = dyn_cast<DeclStmt>(s))) {
1251 if (!ds->isSingleDecl()) {
// Multi-variable declarations are accepted only when no declarator has an
// initializer.
// NOTE(review): the loop variable `d` shadows the callee `Decl * d` above.
// The check tests `!d`, which is never true for an element of decls(). The
// intended test is presumably `!vd`: `vd` is null for a non-VarDecl, and
// `vd->hasInit()` then dereferences null.
1252 for (auto & d : ds->decls()) {
1253 VarDecl * vd = dyn_cast<VarDecl>(d);
1254 if (!d || vd->hasInit())
1255 assert(0 && "unsupported form of decl");
1260 lhs = retrieveSingleDecl(ds);
1263 DeclToMCVar[lhs] = mc_rv;
// For each argument, find the MCID of a plain variable argument
// (MCID_NODEP by default) and insert text in front of the argument.
1266 for (const auto & a : ce->arguments()) {
1267 std::stringstream nol;
1269 std::string aa = "MCID_NODEP";
1271 Expr * e = a->IgnoreCasts();
// NOTE(review): `dr` is null for arguments that are not plain variable
// references. Confirm that `dr` is checked before `dr->getDecl()`.
1272 DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1274 NamedDecl * d = dr->getDecl();
1275 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1276 aa = DeclToMCVar[d];
// NOTE(review): validity is checked on the argument's end location, but
// the insertion uses its start location.
1281 if (a->getLocEnd().isValid())
1282 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
// References to state owned by the AST consumer and shared with the other
// handlers.
1289 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1290 std::set<const FunctionDecl *> &ThreadMains;
// Before each `return`, inserts `*retval = <id>;`. Returns in thread entry
// functions (ThreadMains), in user_main and in unnamed functions are
// skipped. <id> is the MCID of the first local variable in the returned
// expression. It is MCID_NODEP when there is no such variable or none with
// a recorded MCID.
1293 class ReturnHandler : public MatchFinder::MatchCallback {
1295 ReturnHandler(Rewriter &rewrite,
1296 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1297 std::set<const FunctionDecl *> &ThreadMains)
1298 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1300 virtual void run(const MatchFinder::MatchResult &Result) {
1301 const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1302 ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
// NOTE(review): getRetValue() is null for a bare `return;`. The `*retval`
// store below is still emitted unless such returns are filtered out
// elsewhere. FunctionDeclHandler only adds `retval` to non-void functions,
// so confirm this case.
1303 Expr * rv = const_cast<Expr *>(rs->getRetValue());
1306 if (ThreadMains.find(fd) != ThreadMains.end()) return;
1307 // not sure why this is explicitly needed, but crashes without it
// (Likely because NamedDecl::getName() asserts when the name is not a
// simple identifier. getIdentifier() is null in exactly that case.)
1308 if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1310 FindLocalsVisitor flv;
1311 flv.TraverseStmt(rv);
1312 std::string mrv = "MCID_NODEP";
1314 if (flv.RetrieveVars().size() > 0) {
1315 const NamedDecl * returnVar = flv.RetrieveVars()[0];
1316 if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1317 mrv = DeclToMCVar[returnVar];
// NOTE(review): suppose this `return` is the unbraced body of an if/else or
// a loop. Inserting a statement in front of it then moves the return out of
// that body, unless another handler has already braced the body.
1320 std::stringstream nol;
1321 nol << "*retval = " << mrv << ";\n";
1322 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1323 nol.str(), false, true);
// References to state owned by the AST consumer and shared with the other
// handlers.
1328 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1329 std::set<const FunctionDecl *> &ThreadMains;
// For each variable declaration listed in DeclsNeedingMC, inserts
// `MCID <id>; ` immediately before it. <id> is the name recorded in
// DeclToMCVar or, failing that, the encoded variable name.
1332 class VarDeclHandler : public MatchFinder::MatchCallback {
1334 VarDeclHandler(Rewriter &rewrite,
1335 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1336 std::set<const VarDecl *> &DeclsNeedingMC)
1337 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1339 virtual void run(const MatchFinder::MatchResult &Result) {
1340 VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1341 std::stringstream nol;
1343 if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1346 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1347 dn = DeclToMCVar[d];
1349 dn = encode(d->getName().str());
1351 nol << "MCID " << dn << "; ";
// Declarations without a valid start location are skipped.
1353 if (d->getLocStart().isValid())
1354 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
// References to state owned by the AST consumer and shared with the other
// handlers.
1360 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1361 std::set<const VarDecl *> &DeclsNeedingMC;
// Rewrites the signature of each named function declaration it receives,
// except user_main and recorded thread entry functions.
// `MCID <encoded name>, ` is inserted before each parameter. For a non-void
// return type, `MCID * retval` is added as the last parameter. user_main
// itself is added to ThreadMains.
// NOTE(review): ThreadMains is keyed by FunctionDecl pointer, but a
// prototype and its definition are distinct FunctionDecl nodes. Comparing
// getCanonicalDecl() would make these lookups redeclaration-insensitive.
1364 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1366 FunctionDeclHandler(Rewriter &rewrite,
1367 std::set<const FunctionDecl *> &ThreadMains)
1368 : rewrite(rewrite), ThreadMains(ThreadMains) {}
1370 virtual void run(const MatchFinder::MatchResult &Result) {
1371 FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
// getName() below requires a simple identifier name.
1373 if (!fd->getIdentifier()) return;
1375 if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1377 if (ThreadMains.find(fd) != ThreadMains.end()) return;
// Default position for `retval` when there are no parameters: the end of
// the function name plus one character. That is just after the `(` when it
// directly follows the name.
1379 SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1380 for (auto & p : fd->params()) {
1381 std::stringstream nol;
1382 nol << "MCID " << encode(p->getName()) << ", ";
1383 if (p->getLocStart().isValid())
1384 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
// Move the `retval` position past this parameter's name.
// NOTE(review): this assumes the parameter's source range ends at the start
// of its name. Unnamed parameters (empty name) or declarators such as
// `int a[]` may leave LastParam in the wrong place.
1386 if (p->getLocEnd().isValid())
1387 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1390 if (!fd->getReturnType()->isVoidType()) {
1391 std::stringstream nol;
1392 if (fd->param_size() > 0) nol << ", ";
1393 nol << "MCID * retval";
1394 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
// Reference to state owned by the AST consumer and shared with the other
// handlers.
1401 std::set<const FunctionDecl *> &ThreadMains;
// Rejects input that contains a `goto` statement.
// NOTE(review): the check relies on assert(). In an NDEBUG build it compiles
// away, and goto statements then pass through silently.
1404 class BailHandler : public MatchFinder::MatchCallback {
1407 virtual void run(const MatchFinder::MatchResult &Result) {
1408 assert(0 && "we don't handle goto statements");
// AST consumer that owns the shared analysis state and all match handlers,
// registers the AST matchers, and runs the matching passes over the whole
// translation unit.
1412 class MyASTConsumer : public ASTConsumer {
// Each handler receives references to the state members declared below.
1414 MyASTConsumer(Rewriter &R) : R(R),
1418 HandlerMalloc(MallocExprs),
1419 HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1420 HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1421 HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1423 HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1424 HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1425 HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1426 HandlerFunctionDecl(R, ThreadMains),
1427 HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1428 HandlerReturn(R, DeclToMCVar, ThreadMains),
1429 HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
// Pass 1 (MatcherFunctionCall) matches three kinds of call:
//   * calls that are statements of their own;
//   * calls inside a variable declaration;
//   * calls inside an assignment.
// The latter two bind the enclosing statement as "containingStmt".
1431 MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1432 hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1433 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1434 &HandlerFunctionCall);
// Pass 2 (MatcherLoadStore) matches:
//   * malloc/calloc calls;
//   * load_32/load_64, store_32/store_64 and rmw_32/rmw_64 calls;
//   * `if` conditions;
//   * for/while/do loops.
1435 MatcherLoadStore.addMatcher
1436 (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1439 MatcherLoadStore.addMatcher
1440 (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1441 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1442 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1443 hasParent(stmt().bind("containingStmt"))))
1447 MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1450 MatcherLoadStore.addMatcher
1451 (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1452 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1453 hasAncestor(binaryOperator(hasOperatorName("="),
1454 hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1459 MatcherLoadStore.addMatcher(ifStmt(hasCondition
1460 (anyOf(binaryOperator().bind("bc"),
1461 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1462 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1463 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1464 anything()))).bind("if"),
1465 &HandlerBranchConditionRefactoring);
1467 MatcherLoadStore.addMatcher(forStmt().bind("s"),
1469 MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1471 MatcherLoadStore.addMatcher(doStmt().bind("s"),
// Pass 3 (MatcherFunction) matches:
//   * `=` assignments inside a declaration statement or directly in a
//     compound statement;
//   * all declaration statements (HandlerAssign);
//   * every `if` (branch annotation).
1474 MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1475 hasParent(compoundStmt())),
1476 hasOperatorName("=")).bind("op"),
1478 MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1480 MatcherFunction.addMatcher(ifStmt().bind("if"),
1481 &HandlerAnnotateBranch);
// Pass 4 (MatcherFunctionDecl) matches function signatures, variable
// declarations and return statements.
1483 MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1484 &HandlerFunctionDecl);
1485 MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1486 MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
// Pass 5 (MatcherSanity) is a sanity check: any goto aborts.
1489 MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1492 // Called once, after the entire translation unit has been parsed.
// The passes run in a fixed order, because earlier passes fill shared
// state (e.g. ThreadMains, DeclToMCVar) that later passes consult.
// Deferred insertions are applied last, after any text already inserted at
// the same location.
1494 void HandleTranslationUnit(ASTContext &Context) override {
1495 LangOpts = Context.getLangOpts();
1497 MatcherFunctionCall.matchAST(Context);
1498 MatcherLoadStore.matchAST(Context);
1499 MatcherFunction.matchAST(Context);
1500 MatcherFunctionDecl.matchAST(Context);
1501 MatcherSanity.matchAST(Context);
// NOTE(review): the Update objects are not deleted here, so unless they are
// freed elsewhere they leak. std::unique_ptr<Update> would make ownership
// explicit.
1503 for (auto & u : DeferredUpdates) {
1504 R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
1507 DeferredUpdates.clear();
1511 /* DeclsRead contains all local variables 'x' which:
1512 * 1) appear in 'x = load_32(...);'
1513 * 2) appear in 'y = store_32(x);' */
1514 std::set<const NamedDecl *> DeclsRead;
1515 /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
1516 std::set<const NamedDecl *> DeclsInCond;
// Maps a declaration to the name of its MCID variable.
1517 std::map<const NamedDecl *, std::string> DeclToMCVar;
// Maps an expression to the name of its MCID variable.
1518 std::map<const Expr *, std::string> ExprToMCVar;
// Variables that need an `MCID` declaration, emitted by HandlerVarDecl.
1519 std::set<const VarDecl *> DeclsNeedingMC;
// Thread entry functions and user_main. Their signatures and returns are
// left unchanged.
1520 std::set<const FunctionDecl *> ThreadMains;
// Presumably statements already rewritten by the load/RMW/assign handlers.
// Confirm against those handlers.
1521 std::set<const Stmt *> StmtsHandled;
// Presumably the malloc/calloc call expressions, shared by HandlerMalloc
// and HandlerAssign.
1522 std::set<const Expr *> MallocExprs;
// Maps an expression to a source location. It is used by the load, RMW and
// branch-condition handlers, which define its meaning.
1523 std::map<const Expr *, SourceLocation> Redirector;
// Text insertions postponed until every matcher pass has run.
1524 std::vector<Update *> DeferredUpdates;
// Match callbacks; each holds references to the state above.
1528 MallocHandler HandlerMalloc;
1529 LoadHandler HandlerLoad;
1530 StoreHandler HandlerStore;
1531 RMWHandler HandlerRMW;
1532 LoopHandler HandlerLoop;
1533 BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1534 BranchAnnotationHandler HandlerAnnotateBranch;
1535 AssignHandler HandlerAssign;
1536 FunctionDeclHandler HandlerFunctionDecl;
1537 FunctionCallHandler HandlerFunctionCall;
1538 ReturnHandler HandlerReturn;
1539 VarDeclHandler HandlerVarDecl;
1540 BailHandler HandlerBail;
// One MatchFinder per pass. HandleTranslationUnit runs them in the order
// FunctionCall, LoadStore, Function, FunctionDecl, Sanity.
1541 MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1544 // For each source file provided to the tool, a new FrontendAction is created.
// It prints the rewritten main file to stdout and progress messages to
// stderr.
1545 class MyFrontendAction : public ASTFrontendAction {
1547 MyFrontendAction() {}
// After parsing, emit the main file's rewrite buffer. getEditBuffer()
// creates the buffer if it does not exist yet, so the file is printed even
// when no edits were made.
1548 void EndSourceFileAction() override {
1549 SourceManager &SM = TheRewriter.getSourceMgr();
1550 llvm::errs() << "** EndSourceFileAction for: "
1551 << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1553 // Now emit the rewritten buffer.
1554 TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
// Bind the rewriter to this compilation's source manager and language
// options, then hand it to the consumer that performs all insertions.
1557 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1558 StringRef file) override {
1559 llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1560 TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1561 return llvm::make_unique<MyASTConsumer>(TheRewriter);
// Rewriter shared with MyASTConsumer for all text insertions.
1565 Rewriter TheRewriter;
// Entry point. Parses the standard LibTooling command line (source files
// plus compilation-database options) under AddMC2AnnotationsCategory. It
// then runs MyFrontendAction on each source file and returns the result of
// ClangTool::run.
1568 int main(int argc, const char **argv) {
1569 CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1570 ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1572 return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());