add default MC-names for not-found variables in equality test
[satcheck.git] / clang / src / add_mc2_annotations.cpp
index 1b7c92604e31547ebb3e3ce77746fda8b4df9062..740a710ca8f965355f74831d536d49b3cc873097 100644 (file)
@@ -1,4 +1,4 @@
-// -*-  indent-tabs-mode:nil;  -*-
+// -*-  indent-tabs-mode:nil; c-basic-offset:4; -*-
 //------------------------------------------------------------------------------
 // Add MC2 annotations to C code.
 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
@@ -725,7 +725,8 @@ public:
     virtual void run(const MatchFinder::MatchResult &Result) {
         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
 
-        rewrite.InsertText(s->getLocStart(), "MC2_enterLoop();\n", true, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
+                           "MC2_enterLoop();\n", true, true);
 
         // annotate all returns with MC2_exitLoop()
         // annotate all breaks that aren't further nested with MC2_exitLoop().
@@ -743,7 +744,8 @@ public:
         
         // need to find all breaks and returns embedded inside the loop
 
-        rewrite.InsertTextAfterToken(s->getLocEnd().getLocWithOffset(1), "\nMC2_exitLoop();\n");
+        rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
+                                     "\nMC2_exitLoop();\n");
     }
 
 private:
@@ -772,7 +774,7 @@ public:
     virtual void run(const MatchFinder::MatchResult &Result) {
         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
-        FindLocalsVisitor flv;
+        FindLocalsVisitor locals, locals_rhs;
 
         const VarDecl * lhs = NULL;
         const Expr * rhs = NULL;
@@ -808,10 +810,11 @@ public:
         }
         std::set<std::string> mcState;
 
+        bool lhsUsedInCond;
+        bool rhsRead = false;
+
         bool lhsTooComplicated = false;
         if (op) {
-            flv.TraverseStmt(op);
-
             DeclRefExpr * vd;
             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
                 lhs = dyn_cast<VarDecl>(vd->getDecl());
@@ -824,21 +827,37 @@ public:
             if (rhs) 
                 rhs = rhs->IgnoreCasts();
         }
-        else if (lhs) {
-            // rhs must be MC-active state, i.e. in declsread
-            // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
-            flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        // rhs must be MC-active state, i.e. in declsread
+        // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
+
+        if (rhs) {
+            locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+            for (auto & nd : locals_rhs.RetrieveVars()) {
+                if (DeclsRead.find(nd) != DeclsRead.end())
+                    rhsRead = true;
+            }
         }
 
-        if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
-            for (auto & d : flv.RetrieveVars()) {
+        locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
+        if (lhsUsedInCond) {
+            for (auto & d : locals.RetrieveVars()) {
+                if (DeclToMCVar.count(d) > 0)
+                    mcState.insert(DeclToMCVar[d]);
+                else if (DeclsRead.find(d) != DeclsRead.end())
+                    mcState.insert(encode(d->getName().str()));
+            }
+        }
+        if (rhsRead) {
+            for (auto & d : locals_rhs.RetrieveVars()) {
                 if (DeclToMCVar.count(d) > 0)
                     mcState.insert(DeclToMCVar[d]);
                 else if (DeclsRead.find(d) != DeclsRead.end())
                     mcState.insert(encode(d->getName().str()));
             }
         }
-
         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
             if (lhsTooComplicated)
                 assert(0 && "couldn't find LHS of = operator");
@@ -869,11 +888,12 @@ public:
             }
             nol << "); ";
             SourceLocation place;
-            if (op)
-                place = op->getLocEnd().getLocWithOffset(1);
-            else
+            if (op) {
+                place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
+            else
                 place = s->getLocEnd();
-            rewrite.InsertText(place.getLocWithOffset(1), nol.str(), true, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
+                               nol.str(), true, true);
 
             updateProvisionalName(DeferredUpdates, lhs, mcVar);
         }
@@ -892,13 +912,13 @@ public:
 // record vars used in conditions
 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
 public:
-    BranchConditionRefactoringHandler(Rewriter &Rewrite,
+    BranchConditionRefactoringHandler(Rewriter &rewrite,
                                       std::set<const NamedDecl *> & DeclsInCond,
                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
                                       std::map<const Expr *, std::string> &ExprToMCVar,
                                       std::map<const Expr *, SourceLocation> &Redirector,
                                       std::vector<Update *> &DeferredUpdates) :
-        Rewrite(Rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
+        rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
 
     virtual void run(const MatchFinder::MatchResult &Result) {
@@ -920,7 +940,7 @@ public:
             // e.g. int _cond0 = x == y;
             std::string SStr;
             llvm::raw_string_ostream S(SStr);
-            bc->printPretty(S, nullptr, Rewrite.getLangOpts());
+            bc->printPretty(S, nullptr, rewrite.getLangOpts());
             const std::string &Str = S.str();
             
             std::stringstream prel;
@@ -933,8 +953,15 @@ public:
                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
                     is_equality = true;
                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
-                    std::string ld = DeclToMCVar.find(l->getDecl())->second,
+                    std::string ld, rd;
+                    if (DeclToMCVar.find(l->getDecl()) != DeclToMCVar.end())
+                        ld = DeclToMCVar.find(l->getDecl())->second;
+                    else
+                        ld = encode(l->getDecl()->getName());
+                    if (DeclToMCVar.find(r->getDecl()) != DeclToMCVar.end())
                         rd = DeclToMCVar.find(r->getDecl())->second;
+                    else
+                        rd = encode(r->getDecl()->getName());
 
                     prel << "\nint " << condVar << " = MC2_equals(" <<
                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
@@ -954,14 +981,15 @@ public:
             }
 
             ExprToMCVar[cond] = condVarEncoded.str();
-            Rewrite.InsertText(is->getLocStart(), prel.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
+                               prel.str(), false, true);
 
             // rewrite the binary op with the newly-inserted var
             Expr * RO = bc->getRHS(); // used for location only
 
-            int cl = Lexer::MeasureTokenLength(RO->getLocStart(), Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            SourceRange SR(cond->getLocStart(), Rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
-            Rewrite.ReplaceText(SR, condVar);
+            int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
+            SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
+            rewrite.ReplaceText(SR, condVar);
         } else {
             std::string condVar = encodeCond(condCount++);
             std::stringstream condVarEncoded;
@@ -969,7 +997,7 @@ public:
 
             std::string SStr;
             llvm::raw_string_ostream S(SStr);
-            cond->printPretty(S, nullptr, Rewrite.getLangOpts());
+            cond->printPretty(S, nullptr, rewrite.getLangOpts());
             const std::string &Str = S.str();
 
             std::stringstream prel;
@@ -999,7 +1027,7 @@ public:
             // rewrite the call op with the newly-inserted var
             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
             Redirector[cond] = is->getLocStart();
-            Rewrite.ReplaceText(SR, condVar);
+            rewrite.ReplaceText(SR, condVar);
         }
 
         std::deque<const Decl *> q;
@@ -1025,7 +1053,7 @@ public:
     }
 
 private:
-    Rewriter &Rewrite;
+    Rewriter &rewrite;
     std::set<const NamedDecl *> & DeclsInCond;
     std::map<const NamedDecl *, std::string> &DeclToMCVar;
     std::map<const Expr *, std::string> &ExprToMCVar;
@@ -1038,7 +1066,7 @@ public:
     BranchAnnotationHandler(Rewriter &rewrite,
                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
                             std::map<const Expr *, std::string> & ExprToMCVar)
-        : Rewrite(rewrite),
+        : rewrite(rewrite),
           DeclToMCVar(DeclToMCVar),
           ExprToMCVar(ExprToMCVar){}
     virtual void run(const MatchFinder::MatchResult &Result) {
@@ -1065,7 +1093,8 @@ public:
 
         std::stringstream brline;
         brline << "MCID " << brVar << ";\n";
-        Rewrite.InsertText(is->getLocStart(), brline.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
+                           brline.str(), false, true);
 
         Stmt * ts = is->getThen(), * es = is->getElse();
         bool tHasChild = hasChild(ts);
@@ -1087,7 +1116,7 @@ public:
 
         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
 
-        Rewrite.InsertText(tfl, tlineStart.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
 
         Stmt * tls = NULL;
         int extra_else_offset = 0;
@@ -1097,13 +1126,13 @@ public:
 
         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
             extra_else_offset = 0;
-            Rewrite.InsertText(tsl.getLocWithOffset(1), mergeStmt.str(), true, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
+                               mergeStmt.str(), true, true);
         }
         if (tHasChild && !isa<CompoundStmt>(ts)) {
-            Rewrite.InsertText(tls->getLocStart(), "{", false, true);
-            SourceLocation tend = Lexer::getLocForEndOfToken(tls->getLocStart(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            Rewrite.InsertText(tend.getLocWithOffset(2), "}", true, true);
-            extra_else_offset++;
+            rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
+            SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
+            rewrite.InsertText(tend, "}", true, true);
         }
         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
 
@@ -1124,22 +1153,23 @@ public:
                 }
             } else
                 el = es->getLocStart().getLocWithOffset(1);
-            Rewrite.InsertText(el, eline.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
 
             if (eHasChild && !isa<CompoundStmt>(es)) {
-                Rewrite.InsertText(el, "{", false, true);
-                Rewrite.InsertText(es->getLocEnd().getLocWithOffset(1), "}", true, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
             }
 
             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
-                Rewrite.InsertText(esl.getLocWithOffset(1), mergeStmt.str(), true, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
         }
         else {
             std::stringstream eCompoundLine;
             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
-            SourceLocation tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            Rewrite.InsertText(tend.getLocWithOffset(extra_else_offset),
-              eCompoundLine.str(), false, true);
+            SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
+            if (!tend.isValid())
+                tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
+            rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
         }
     }
 private:
@@ -1164,7 +1194,7 @@ private:
         return s;
     }
 
-    Rewriter &Rewrite;
+    Rewriter &rewrite;
     std::map<const NamedDecl *, std::string> &DeclToMCVar;
     std::map<const Expr *, std::string> &ExprToMCVar;
 };
@@ -1208,12 +1238,14 @@ public:
 
             std::stringstream brline;
             brline << "MCID " << mc_rv << ";\n";
-            rewrite.InsertText(s->getLocStart(), brline.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
+                               brline.str(), false, true);
 
             std::stringstream nol;
             if (ce->getNumArgs() > 0) nol << ", ";
             nol << "&" << mc_rv;
-            rewrite.InsertTextBefore(ce->getRParenLoc(), nol.str());
+            rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
+                                     nol.str());
 
             if (s && (ds = dyn_cast<DeclStmt>(s))) {
                 if (!ds->isSingleDecl()) {
@@ -1247,7 +1279,8 @@ public:
             nol << aa << ", ";
             
             if (a->getLocEnd().isValid())
-                rewrite.InsertTextBefore(a->getLocStart(), nol.str());
+                rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
+                                         nol.str());
         }
     }
 
@@ -1286,7 +1319,8 @@ public:
         }
         std::stringstream nol;
         nol << "*retval = " << mrv << ";\n";
-        rewrite.InsertText(rs->getLocStart(), nol.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
+                           nol.str(), false, true);
     }
 
 private:
@@ -1317,7 +1351,8 @@ public:
         nol << "MCID " << dn << "; ";
 
         if (d->getLocStart().isValid())
-            rewrite.InsertTextBefore(d->getLocStart(), nol.str());
+            rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
+                                     nol.str());
     }
 
 private:
@@ -1346,7 +1381,8 @@ public:
             std::stringstream nol;
             nol << "MCID " << encode(p->getName()) << ", ";
             if (p->getLocStart().isValid())
-                rewrite.InsertText(p->getLocStart(), nol.str(), false);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
+                                   nol.str(), false);
             if (p->getLocEnd().isValid())
                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
         }
@@ -1355,7 +1391,8 @@ public:
             std::stringstream nol;
             if (fd->param_size() > 0) nol << ", ";
             nol << "MCID * retval";
-            rewrite.InsertText(LastParam, nol.str(), false);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
+                               nol.str(), false);
         }
     }
 
@@ -1464,7 +1501,7 @@ public:
         MatcherSanity.matchAST(Context);
 
         for (auto & u : DeferredUpdates) {
-            R.InsertText(u->loc, u->update, true, true);
+            R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
             delete u;
         }
         DeferredUpdates.clear();
@@ -1474,7 +1511,9 @@ private:
     /* DeclsRead contains all local variables 'x' which:
     * 1) appear in 'x = load_32(...);
     * 2) appear in 'y = store_32(x); */
-    std::set<const NamedDecl *> DeclsRead, DeclsInCond;
+    std::set<const NamedDecl *> DeclsRead;
+    /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
+    std::set<const NamedDecl *> DeclsInCond;
     std::map<const NamedDecl *, std::string> DeclToMCVar;
     std::map<const Expr *, std::string> ExprToMCVar;
     std::set<const VarDecl *> DeclsNeedingMC;