add default MC-names for not-found variables in equality test
[satcheck.git] / clang / src / add_mc2_annotations.cpp
index a2539e545c3a98d1e9884429a49e9c00edd2cf44..740a710ca8f965355f74831d536d49b3cc873097 100644 (file)
@@ -774,7 +774,7 @@ public:
     virtual void run(const MatchFinder::MatchResult &Result) {
         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
-        FindLocalsVisitor flv;
+        FindLocalsVisitor locals, locals_rhs;
 
         const VarDecl * lhs = NULL;
         const Expr * rhs = NULL;
@@ -810,10 +810,11 @@ public:
         }
         std::set<std::string> mcState;
 
+        bool lhsUsedInCond;
+        bool rhsRead = false;
+
         bool lhsTooComplicated = false;
         if (op) {
-            flv.TraverseStmt(op);
-
             DeclRefExpr * vd;
             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
                 lhs = dyn_cast<VarDecl>(vd->getDecl());
@@ -826,21 +827,37 @@ public:
             if (rhs) 
                 rhs = rhs->IgnoreCasts();
         }
-        else if (lhs) {
-            // rhs must be MC-active state, i.e. in declsread
-            // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
-            flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        // rhs must be MC-active state, i.e. in declsread
+        // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
+
+        if (rhs) {
+            locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+            for (auto & nd : locals_rhs.RetrieveVars()) {
+                if (DeclsRead.find(nd) != DeclsRead.end())
+                    rhsRead = true;
+            }
         }
 
-        if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
-            for (auto & d : flv.RetrieveVars()) {
+        locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
+        if (lhsUsedInCond) {
+            for (auto & d : locals.RetrieveVars()) {
+                if (DeclToMCVar.count(d) > 0)
+                    mcState.insert(DeclToMCVar[d]);
+                else if (DeclsRead.find(d) != DeclsRead.end())
+                    mcState.insert(encode(d->getName().str()));
+            }
+        }
+        if (rhsRead) {
+            for (auto & d : locals_rhs.RetrieveVars()) {
                 if (DeclToMCVar.count(d) > 0)
                     mcState.insert(DeclToMCVar[d]);
                 else if (DeclsRead.find(d) != DeclsRead.end())
                     mcState.insert(encode(d->getName().str()));
             }
         }
-
         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
             if (lhsTooComplicated)
                 assert(0 && "couldn't find LHS of = operator");
@@ -871,9 +888,9 @@ public:
             }
             nol << "); ";
             SourceLocation place;
-            if (op)
-                place = op->getLocEnd().getLocWithOffset(1);
-            else
+            if (op) {
+                place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
+            else
                 place = s->getLocEnd();
             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
                                nol.str(), true, true);
@@ -936,8 +953,15 @@ public:
                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
                     is_equality = true;
                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
-                    std::string ld = DeclToMCVar.find(l->getDecl())->second,
+                    std::string ld, rd;
+                    if (DeclToMCVar.find(l->getDecl()) != DeclToMCVar.end())
+                        ld = DeclToMCVar.find(l->getDecl())->second;
+                    else
+                        ld = encode(l->getDecl()->getName());
+                    if (DeclToMCVar.find(r->getDecl()) != DeclToMCVar.end())
                         rd = DeclToMCVar.find(r->getDecl())->second;
+                    else
+                        rd = encode(r->getDecl()->getName());
 
                     prel << "\nint " << condVar << " = MC2_equals(" <<
                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
@@ -1487,7 +1511,9 @@ private:
     /* DeclsRead contains all local variables 'x' which:
     * 1) appear in 'x = load_32(...);
     * 2) appear in 'y = store_32(x); */
-    std::set<const NamedDecl *> DeclsRead, DeclsInCond;
+    std::set<const NamedDecl *> DeclsRead;
+    /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
+    std::set<const NamedDecl *> DeclsInCond;
     std::map<const NamedDecl *, std::string> DeclToMCVar;
     std::map<const Expr *, std::string> ExprToMCVar;
     std::set<const VarDecl *> DeclsNeedingMC;