fix replacement around macro expansion; add new benchmarks from Stavros Aronis
[satcheck.git] / clang / src / add_mc2_annotations.cpp
index 1b7c92604e31547ebb3e3ce77746fda8b4df9062..a2539e545c3a98d1e9884429a49e9c00edd2cf44 100644 (file)
@@ -1,4 +1,4 @@
-// -*-  indent-tabs-mode:nil;  -*-
+// -*-  indent-tabs-mode:nil; c-basic-offset:4; -*-
 //------------------------------------------------------------------------------
 // Add MC2 annotations to C code.
 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
@@ -725,7 +725,8 @@ public:
     virtual void run(const MatchFinder::MatchResult &Result) {
         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
 
-        rewrite.InsertText(s->getLocStart(), "MC2_enterLoop();\n", true, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
+                           "MC2_enterLoop();\n", true, true);
 
         // annotate all returns with MC2_exitLoop()
         // annotate all breaks that aren't further nested with MC2_exitLoop().
@@ -743,7 +744,8 @@ public:
         
         // need to find all breaks and returns embedded inside the loop
 
-        rewrite.InsertTextAfterToken(s->getLocEnd().getLocWithOffset(1), "\nMC2_exitLoop();\n");
+        rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
+                                     "\nMC2_exitLoop();\n");
     }
 
 private:
@@ -873,7 +875,8 @@ public:
                 place = op->getLocEnd().getLocWithOffset(1);
             else
                 place = s->getLocEnd();
-            rewrite.InsertText(place.getLocWithOffset(1), nol.str(), true, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
+                               nol.str(), true, true);
 
             updateProvisionalName(DeferredUpdates, lhs, mcVar);
         }
@@ -892,13 +895,13 @@ public:
 // record vars used in conditions
 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
 public:
-    BranchConditionRefactoringHandler(Rewriter &Rewrite,
+    BranchConditionRefactoringHandler(Rewriter &rewrite,
                                       std::set<const NamedDecl *> & DeclsInCond,
                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
                                       std::map<const Expr *, std::string> &ExprToMCVar,
                                       std::map<const Expr *, SourceLocation> &Redirector,
                                       std::vector<Update *> &DeferredUpdates) :
-        Rewrite(Rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
+        rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
 
     virtual void run(const MatchFinder::MatchResult &Result) {
@@ -920,7 +923,7 @@ public:
             // e.g. int _cond0 = x == y;
             std::string SStr;
             llvm::raw_string_ostream S(SStr);
-            bc->printPretty(S, nullptr, Rewrite.getLangOpts());
+            bc->printPretty(S, nullptr, rewrite.getLangOpts());
             const std::string &Str = S.str();
             
             std::stringstream prel;
@@ -954,14 +957,15 @@ public:
             }
 
             ExprToMCVar[cond] = condVarEncoded.str();
-            Rewrite.InsertText(is->getLocStart(), prel.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
+                               prel.str(), false, true);
 
             // rewrite the binary op with the newly-inserted var
             Expr * RO = bc->getRHS(); // used for location only
 
-            int cl = Lexer::MeasureTokenLength(RO->getLocStart(), Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            SourceRange SR(cond->getLocStart(), Rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
-            Rewrite.ReplaceText(SR, condVar);
+            int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
+            SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
+            rewrite.ReplaceText(SR, condVar);
         } else {
             std::string condVar = encodeCond(condCount++);
             std::stringstream condVarEncoded;
@@ -969,7 +973,7 @@ public:
 
             std::string SStr;
             llvm::raw_string_ostream S(SStr);
-            cond->printPretty(S, nullptr, Rewrite.getLangOpts());
+            cond->printPretty(S, nullptr, rewrite.getLangOpts());
             const std::string &Str = S.str();
 
             std::stringstream prel;
@@ -999,7 +1003,7 @@ public:
             // rewrite the call op with the newly-inserted var
             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
             Redirector[cond] = is->getLocStart();
-            Rewrite.ReplaceText(SR, condVar);
+            rewrite.ReplaceText(SR, condVar);
         }
 
         std::deque<const Decl *> q;
@@ -1025,7 +1029,7 @@ public:
     }
 
 private:
-    Rewriter &Rewrite;
+    Rewriter &rewrite;
     std::set<const NamedDecl *> & DeclsInCond;
     std::map<const NamedDecl *, std::string> &DeclToMCVar;
     std::map<const Expr *, std::string> &ExprToMCVar;
@@ -1038,7 +1042,7 @@ public:
     BranchAnnotationHandler(Rewriter &rewrite,
                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
                             std::map<const Expr *, std::string> & ExprToMCVar)
-        : Rewrite(rewrite),
+        : rewrite(rewrite),
           DeclToMCVar(DeclToMCVar),
           ExprToMCVar(ExprToMCVar){}
     virtual void run(const MatchFinder::MatchResult &Result) {
@@ -1065,7 +1069,8 @@ public:
 
         std::stringstream brline;
         brline << "MCID " << brVar << ";\n";
-        Rewrite.InsertText(is->getLocStart(), brline.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
+                           brline.str(), false, true);
 
         Stmt * ts = is->getThen(), * es = is->getElse();
         bool tHasChild = hasChild(ts);
@@ -1087,7 +1092,7 @@ public:
 
         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
 
-        Rewrite.InsertText(tfl, tlineStart.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
 
         Stmt * tls = NULL;
         int extra_else_offset = 0;
@@ -1097,13 +1102,13 @@ public:
 
         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
             extra_else_offset = 0;
-            Rewrite.InsertText(tsl.getLocWithOffset(1), mergeStmt.str(), true, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
+                               mergeStmt.str(), true, true);
         }
         if (tHasChild && !isa<CompoundStmt>(ts)) {
-            Rewrite.InsertText(tls->getLocStart(), "{", false, true);
-            SourceLocation tend = Lexer::getLocForEndOfToken(tls->getLocStart(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            Rewrite.InsertText(tend.getLocWithOffset(2), "}", true, true);
-            extra_else_offset++;
+            rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
+            SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
+            rewrite.InsertText(tend, "}", true, true);
         }
         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
 
@@ -1124,22 +1129,23 @@ public:
                 }
             } else
                 el = es->getLocStart().getLocWithOffset(1);
-            Rewrite.InsertText(el, eline.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
 
             if (eHasChild && !isa<CompoundStmt>(es)) {
-                Rewrite.InsertText(el, "{", false, true);
-                Rewrite.InsertText(es->getLocEnd().getLocWithOffset(1), "}", true, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
             }
 
             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
-                Rewrite.InsertText(esl.getLocWithOffset(1), mergeStmt.str(), true, true);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
         }
         else {
             std::stringstream eCompoundLine;
             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
-            SourceLocation tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
-            Rewrite.InsertText(tend.getLocWithOffset(extra_else_offset),
-              eCompoundLine.str(), false, true);
+            SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
+            if (!tend.isValid())
+                tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
+            rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
         }
     }
 private:
@@ -1164,7 +1170,7 @@ private:
         return s;
     }
 
-    Rewriter &Rewrite;
+    Rewriter &rewrite;
     std::map<const NamedDecl *, std::string> &DeclToMCVar;
     std::map<const Expr *, std::string> &ExprToMCVar;
 };
@@ -1208,12 +1214,14 @@ public:
 
             std::stringstream brline;
             brline << "MCID " << mc_rv << ";\n";
-            rewrite.InsertText(s->getLocStart(), brline.str(), false, true);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
+                               brline.str(), false, true);
 
             std::stringstream nol;
             if (ce->getNumArgs() > 0) nol << ", ";
             nol << "&" << mc_rv;
-            rewrite.InsertTextBefore(ce->getRParenLoc(), nol.str());
+            rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
+                                     nol.str());
 
             if (s && (ds = dyn_cast<DeclStmt>(s))) {
                 if (!ds->isSingleDecl()) {
@@ -1247,7 +1255,8 @@ public:
             nol << aa << ", ";
             
             if (a->getLocEnd().isValid())
-                rewrite.InsertTextBefore(a->getLocStart(), nol.str());
+                rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
+                                         nol.str());
         }
     }
 
@@ -1286,7 +1295,8 @@ public:
         }
         std::stringstream nol;
         nol << "*retval = " << mrv << ";\n";
-        rewrite.InsertText(rs->getLocStart(), nol.str(), false, true);
+        rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
+                           nol.str(), false, true);
     }
 
 private:
@@ -1317,7 +1327,8 @@ public:
         nol << "MCID " << dn << "; ";
 
         if (d->getLocStart().isValid())
-            rewrite.InsertTextBefore(d->getLocStart(), nol.str());
+            rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
+                                     nol.str());
     }
 
 private:
@@ -1346,7 +1357,8 @@ public:
             std::stringstream nol;
             nol << "MCID " << encode(p->getName()) << ", ";
             if (p->getLocStart().isValid())
-                rewrite.InsertText(p->getLocStart(), nol.str(), false);
+                rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
+                                   nol.str(), false);
             if (p->getLocEnd().isValid())
                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
         }
@@ -1355,7 +1367,8 @@ public:
             std::stringstream nol;
             if (fd->param_size() > 0) nol << ", ";
             nol << "MCID * retval";
-            rewrite.InsertText(LastParam, nol.str(), false);
+            rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
+                               nol.str(), false);
         }
     }
 
@@ -1464,7 +1477,7 @@ public:
         MatcherSanity.matchAST(Context);
 
         for (auto & u : DeferredUpdates) {
-            R.InsertText(u->loc, u->update, true, true);
+            R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
             delete u;
         }
         DeferredUpdates.clear();