add MC2_function call for assignments where RHS computed from loads; tweak tests
[satcheck.git] / clang / src / add_mc2_annotations.cpp
1 // -*-  indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
5 //
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
16 //
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
20 // distribution.
21 //
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
25 //
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
34 //
35 // Patrick Lam (prof.lam@gmail.com)
36 //
37 // Base code:
38 // Eli Bendersky (eliben@gmail.com)
39 //
40 //------------------------------------------------------------------------------
41 #include <sstream>
42 #include <string>
43 #include <map>
44 #include <stdbool.h>
45
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
61
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
66 using namespace llvm;
67
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
70
// Builds the identifier used for the MCID companion of a program variable:
// the variable's name with "_m" prepended.
static std::string encode(std::string varName) {
    return "_m" + varName;
}
76
// Counter declared next to encodeFn; presumably supplies its argument --
// confirm at the call sites, which are outside this block.
static int fnCount;

// Returns the generated identifier "_fn<num>".
static std::string encodeFn(int num) {
    return "_fn" + std::to_string(num);
}
83
// Counter declared next to encodePtr; presumably supplies its argument --
// confirm at the call sites, which are outside this block.
static int ptrCount;

// Returns the generated identifier "_p<num>".
static std::string encodePtr(int num) {
    return "_p" + std::to_string(num);
}
90
// Number of RMW result names handed out so far; RMWHandler post-increments it
// each time it names a new read-modify-write result.
static int rmwCount;

// Returns the generated identifier "_rmw<num>".
static std::string encodeRMW(int num) {
    return "_rmw" + std::to_string(num);
}
97
// Counter declared next to encodeBranch; presumably supplies its argument --
// confirm at the call sites, which are outside this block.
static int branchCount;

// Returns the generated identifier "_br<num>".
static std::string encodeBranch(int num) {
    return "_br" + std::to_string(num);
}
104
// Counter declared next to encodeCond; presumably supplies its argument --
// confirm at the call sites, which are outside this block.
static int condCount;

// Returns the generated identifier "_cond<num>".
static std::string encodeCond(int num) {
    return "_cond" + std::to_string(num);
}
111
// Counter declared next to encodeRV; presumably supplies its argument --
// confirm at the call sites, which are outside this block.
static int rvCount;

// Returns the generated identifier "_rv<num>".
static std::string encodeRV(int num) {
    return "_rv" + std::to_string(num);
}
118
// Last id emitted into a generated MC2_function_id(...) call;
// generateMC2Function pre-increments it, so the first id handed out is 1.
static int funcCount = 0;
120
121 struct ProvisionalName {
122     int index, length;
123     const DeclRefExpr * pname;
124     bool enabled;
125
126     ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127     ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
128 };
129
130 struct Update {
131     SourceLocation loc;
132     std::string update;
133     std::vector<ProvisionalName *> * pnames;
134
135     Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) : 
136         loc(loc), update(update), pnames(pnames) {}
137
138     ~Update() { 
139         for (auto pname : *pnames) delete pname;
140         delete pnames; 
141     }
142 };
143
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145     for (Update * u : DeferredUpdates) {
146         for (int i = 0; i < u->pnames->size(); i++) {
147             ProvisionalName * v = (*(u->pnames))[i];
148             if (!v->enabled) continue;
149             if (now_known == v->pname->getDecl()) {
150                 v->enabled = false;
151                 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
152
153                 u->update.replace(v->index, v->length, mcVar);
154                 for (int j = i+1; j < u->pnames->size(); j++) {
155                     ProvisionalName * vv = (*(u->pnames))[j];
156                     if (vv->index > v->index)
157                         vv->index -= v->length - mcVar.length();
158                 }
159             }
160         }
161     }
162 }
163
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165     // XXX iterate through all decls defined in s, not just the first one
166     assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167     if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168         return cast<VarDecl>(s->getSingleDecl());
169     } else return NULL;
170 }
171
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
173 public:
174     FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
175
176     bool VisitStmt(Stmt * s) {
177         if (!UnaryOp) {
178             if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179                 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180                     uo->getOpcode() == UnaryOperatorKind::UO_Deref)
181                     UnaryOp = uo;
182             }
183         }
184
185         if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
186             ;
187         return true;
188     }
189
190     void Clear() {
191         UnaryOp = NULL; DE = NULL;
192     }
193
194     const UnaryOperator * RetrieveUnaryOp() {
195         if (UnaryOp) {
196             bool found = false;
197             const Stmt * s = UnaryOp;
198             while (s != NULL) {
199                 if (s == DE) {
200                     found = true; break;
201                 }
202
203                 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204                     s = op->getSubExpr();
205                 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206                     s = op->getSubExpr();
207                 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
208                     s = op->getBase();
209                 else
210                     s = NULL;
211             }
212             if (found)
213                 return UnaryOp;
214         }
215         return NULL;
216     }
217
218     const DeclRefExpr * RetrieveDeclRefExpr() {
219         return DE;
220     }
221
222 private:
223     const UnaryOperator * UnaryOp;
224     const DeclRefExpr * DE;
225 };
226
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
228 public:
229     FindLocalsVisitor() : Vars() {}
230
231     bool VisitDeclRefExpr(DeclRefExpr * de) {
232         Vars.push_back(de->getDecl());
233         return true;
234     }
235
236     void Clear() {
237         Vars.clear();
238     }
239
240     const TinyPtrVector<const NamedDecl *> RetrieveVars() {
241         return Vars;
242     }
243
244 private:
245     TinyPtrVector<const NamedDecl *> Vars;
246 };
247
248
249 class MallocHandler : public MatchFinder::MatchCallback {
250 public:
251     MallocHandler(std::set<const Expr *> & MallocExprs) :
252         MallocExprs(MallocExprs) {}
253
254     virtual void run(const MatchFinder::MatchResult &Result) {
255         const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
256
257         MallocExprs.insert(ce);
258     }
259
260     private:
261     std::set<const Expr *> &MallocExprs;
262 };
263
264 static void generateMC2Function(Rewriter & Rewrite,
265                                 const Expr * e, 
266                                 SourceLocation loc,
267                                 std::string tmpname, 
268                                 std::string tmpFn, 
269                                 const DeclRefExpr * lhs, 
270                                 std::string lhsName,
271                                 std::vector<ProvisionalName *> * vars1,
272                                 std::vector<Update *> & DeferredUpdates) {
273     // prettyprint the LHS (&newnode->value)
274     // e.g. int * _tmp0 = &newnode->value;
275     std::string SStr;
276     llvm::raw_string_ostream S(SStr);
277     e->printPretty(S, nullptr, Rewrite.getLangOpts());
278     const std::string &Str = S.str();
279
280     std::stringstream prel;
281     prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
282
283     // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284     prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
285     if (lhs) {
286         // XXX generate casts when they'd eliminate warnings
287         ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
288         vars1->push_back(v);
289     }
290     prel << encode(lhsName) << "); ";
291
292     Update * u = new Update(loc, prel.str(), vars1);
293     DeferredUpdates.push_back(u);
294 }
295
296 class LoadHandler : public MatchFinder::MatchCallback {
297 public:
298     LoadHandler(Rewriter &Rewrite,
299                 std::set<const NamedDecl *> & DeclsRead,
300                 std::set<const VarDecl *> & DeclsNeedingMC, 
301                 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302                 std::map<const Expr *, std::string> &ExprToMCVar,
303                 std::set<const Stmt *> & StmtsHandled,
304                 std::map<const Expr *, SourceLocation> &Redirector,
305                 std::vector<Update *> & DeferredUpdates) : 
306         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307         ExprToMCVar(ExprToMCVar),
308         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
309
310     virtual void run(const MatchFinder::MatchResult &Result) {
311         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313         const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315         const Expr * lhs = NULL;
316         if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
317         if (!s) s = ce;
318
319         const DeclRefExpr * rhs = NULL;
320         MemberExpr * ml = NULL;
321         bool isAddrOfR = false, isAddrMemberR = false;
322
323         StmtsHandled.insert(s);
324         
325         std::string n, n_decl;
326         if (lhs) {
327             FindCallArgVisitor fcaVisitor;
328             fcaVisitor.Clear();
329             fcaVisitor.TraverseStmt(ce->getArg(0));
330             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
334             if (isAddrMemberR)
335                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
336
337             FindLocalsVisitor flv;
338             flv.Clear();
339             flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340             for (auto & d : flv.RetrieveVars()) {
341                 const VarDecl * dd = cast<VarDecl>(d);
342                 n = dd->getName();
343                 // XXX todo rhs for non-decl stmts
344                 if (!isa<ParmVarDecl>(dd))
345                     DeclsNeedingMC.insert(dd);
346                 DeclsRead.insert(d);
347                 DeclToMCVar[dd] = encode(n);
348             }
349         } else {
350             FindCallArgVisitor fcaVisitor;
351             fcaVisitor.Clear();
352             fcaVisitor.TraverseStmt(ce->getArg(0));
353             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
357             if (isAddrMemberR)
358                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
359
360             if (d) {
361                 n = d->getName();
362                 DeclsNeedingMC.insert(d);
363                 DeclsRead.insert(d);
364                 DeclToMCVar[d] = encode(n);
365             } else {
366                 n = ExprToMCVar[ce];
367                 fcaVisitor.Clear();
368                 fcaVisitor.TraverseStmt(ce);
369                 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370                 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371                 DeclToMCVar[dd->getDecl()] = encode(n);
372                 n_decl = "MCID ";
373             }
374         }
375         
376         std::stringstream nol;
377
378         if (lhs && isa<DeclRefExpr>(lhs)) {
379             const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380             ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
381             vars->push_back(v);
382         }
383
384         if (rhs) {
385             if (isAddrMemberR) {
386                 if (!n.empty()) 
387                     nol << n_decl << encode(n) << "=";
388                 nol << "MC2_nextOpLoadOffset(";
389
390                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
391                 vars->push_back(v);
392                 nol << encode(rhs->getNameInfo().getName().getAsString());
393
394                 nol << ", MC2_OFFSET(";
395                 nol << ml->getBase()->getType().getAsString();
396                 nol << ", ";
397                 nol << ml->getMemberDecl()->getName().str();
398                 nol << ")";
399             } else if (!isAddrOfR) {
400                 if (!n.empty()) 
401                     nol << n_decl << encode(n) << "=";
402                 nol << "MC2_nextOpLoad(";
403                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
404                 vars->push_back(v);
405                 nol << encode(rhs->getNameInfo().getName().getAsString());
406             } else {
407                 if (!n.empty()) 
408                     nol << n_decl << encode(n) << "=";
409                 nol << "MC2_nextOpLoad(";
410                 nol << "MCID_NODEP";
411             }
412         } else {
413             if (!n.empty()) 
414                 nol << n_decl << encode(n) << "=";
415             nol << "MC2_nextOpLoad(";
416             nol << "MCID_NODEP";
417         }
418
419         if (lhs)
420             nol << "), ";
421         else
422             nol << "); ";
423         SourceLocation ss = s->getLocStart();
424         // eek gross hack:
425         // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426         // move over 1 so that we get the right location.
427         if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428         const Expr * e = dyn_cast<Expr>(s);
429         if (e && Redirector.count(e) > 0)
430             ss = Redirector[e];
431         Update * u = new Update(ss, nol.str(), vars);
432         DeferredUpdates.insert(DeferredUpdates.begin(), u);
433     }
434
435     private:
436     Rewriter &Rewrite;
437     std::set<const NamedDecl *> & DeclsRead;
438     std::set<const VarDecl *> & DeclsNeedingMC;
439     std::map<const Expr *, std::string> &ExprToMCVar;
440     std::map<const NamedDecl *, std::string> &DeclToMCVar;
441     std::set<const Stmt *> &StmtsHandled;
442     std::vector<Update *> &DeferredUpdates;
443     std::map<const Expr *, SourceLocation> &Redirector;
444 };
445
446 class StoreHandler : public MatchFinder::MatchCallback {
447 public:
448     StoreHandler(Rewriter &Rewrite,
449                  std::set<const NamedDecl *> & DeclsRead,
450                  std::set<const VarDecl *> &DeclsNeedingMC,
451                  std::vector<Update *> & DeferredUpdates) : 
452         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
453
454     virtual void run(const MatchFinder::MatchResult &Result) {
455         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
457
458         fcaVisitor.Clear();
459         fcaVisitor.TraverseStmt(ce->getArg(0));
460         const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461         const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
462     
463         std::stringstream nol;
464
465         bool isAddrMemberL;
466         bool isAddrOfL;
467
468         if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469             isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470             isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
471         }
472
473         if (lhs) {
474             if (isAddrOfL) {
475                 nol << "MC2_nextOpStore(";
476                 nol << "MCID_NODEP";
477             } else {
478                 if (isAddrMemberL) {
479                     MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
480
481                     nol << "MC2_nextOpStoreOffset(";
482
483                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
484                     vars->push_back(v);
485
486                     nol << encode(lhs->getNameInfo().getName().getAsString());
487                     if (!isa<ParmVarDecl>(lhs->getDecl()))
488                         DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
489
490                     nol << ", MC2_OFFSET(";
491                     nol << ml->getBase()->getType().getAsString();
492                     nol << ", ";
493                     nol << ml->getMemberDecl()->getName().str();
494                     nol << ")";
495                 } else {
496                     nol << "MC2_nextOpStore(";
497                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
498                     vars->push_back(v);
499
500                     nol << encode(lhs->getNameInfo().getName().getAsString());
501                 }
502             }
503         }
504         else {
505             nol << "MC2_nextOpStore(";
506             nol << "MCID_NODEP";
507         }
508         
509         nol << ", ";
510
511         fcaVisitor.Clear();
512         fcaVisitor.TraverseStmt(ce->getArg(1));
513         const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514         const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
515
516         bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517         bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
518
519         if (rhs && !isAddrOfR) {
520             assert (!isDerefR && "Must use atomic load for dereferences!");
521             ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
522             vars->push_back(v);
523
524             nol << encode(rhs->getNameInfo().getName().getAsString());
525             DeclsRead.insert(rhs->getDecl());
526         }
527         else
528             nol << "MCID_NODEP";
529         
530         nol << ");\n";
531         Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532         DeferredUpdates.push_back(u);
533     }
534
535     private:
536     Rewriter &Rewrite;
537     FindCallArgVisitor fcaVisitor;
538     std::set<const NamedDecl *> & DeclsRead;
539     std::set<const VarDecl *> & DeclsNeedingMC;
540     std::vector<Update *> &DeferredUpdates;
541 };
542
543 class RMWHandler : public MatchFinder::MatchCallback {
544 public:
545     RMWHandler(Rewriter &rewrite,
546                  std::set<const NamedDecl *> & DeclsRead,
547                  std::set<const NamedDecl *> & DeclsInCond,
548                  std::map<const NamedDecl *, std::string> &DeclToMCVar,
549                  std::map<const Expr *, std::string> &ExprToMCVar,
550                  std::set<const Stmt *> & StmtsHandled,
551                  std::map<const Expr *, SourceLocation> &Redirector,
552                  std::vector<Update *> & DeferredUpdates) : 
553         rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
554         ExprToMCVar(ExprToMCVar),
555         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
556
557     virtual void run(const MatchFinder::MatchResult &Result) {
558         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
561
562         std::stringstream nol;
563         
564         std::string rmwMCVar;
565         rmwMCVar = encodeRMW(rmwCount++);
566
567         const VarDecl * rmw_lhs;
568         if (s) {
569             StmtsHandled.insert(s);
570             assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
571             const DeclStmt * ds;
572             if ((ds = dyn_cast<DeclStmt>(s))) {
573                 rmw_lhs = retrieveSingleDecl(ds);
574             } else {
575                 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576                 assert (isa<DeclRefExpr>(e));
577                 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
578             }
579             DeclToMCVar[rmw_lhs] = rmwMCVar;
580         }
581
582         // retrieve effective LHS of the RMW
583         fcaVisitor.Clear();
584         fcaVisitor.TraverseStmt(ce->getArg(1));
585         const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586         const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587         bool isAddrMemberL = false;
588
589         if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590             isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
591         }
592
593         nol << "MCID " << rmwMCVar;
594         if (isAddrMemberL) {
595             MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
596
597             nol << " = MC2_nextRMWOffset(";
598
599             ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
600             vars->push_back(v);
601
602             nol << encode(elhs->getNameInfo().getName().getAsString());
603
604             nol << ", MC2_OFFSET(";
605             nol << ml->getBase()->getType().getAsString();
606             nol << ", ";
607             nol << ml->getMemberDecl()->getName().str();
608             nol << ")";
609         } else {
610             nol << " = MC2_nextRMW(";
611             bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
612
613             if (elhs) {
614                 if (isAddrOfL)
615                     nol << "MCID_NODEP";
616                 else {
617                     ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
618                     vars->push_back(v);
619
620                     std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
621                     nol << elhsName;
622                 }
623             }
624             else
625                 nol << "MCID_NODEP";
626         }
627         nol << ", ";
628
629         // handle both RHS ops
630         int outputted = 0;
631         for (int arg = 2; arg < 4; arg++) {
632             fcaVisitor.Clear();
633             fcaVisitor.TraverseStmt(ce->getArg(arg));
634             const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635             const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
636             
637             bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638             bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
639
640             if (a && !isAddrOfR) {
641                 assert (!isDerefR && "Must use atomic load for dereferences!");
642
643                 DeclsInCond.insert(a->getDecl());
644
645                 if (outputted > 0) nol << ", ";
646                 outputted++;
647
648                 bool alreadyMCVar = false;
649                 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
650                     alreadyMCVar = true;
651                     nol << DeclToMCVar[a->getDecl()];
652                 }
653                 else {
654                     std::string an = "MCID_NODEP";
655                     ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
656                     nol << an;
657                     vars->push_back(v);
658                 }
659
660                 DeclsRead.insert(a->getDecl());
661             }
662             else {
663                 if (outputted > 0) nol << ", ";
664                 outputted++;
665
666                 nol << "MCID_NODEP";
667             }
668         }
669         nol << ");\n";
670
671         SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672         const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673         if (e && Redirector.count(e) > 0)
674             place = Redirector[e];
675         Update * u = new Update(place, nol.str(), vars);
676         DeferredUpdates.insert(DeferredUpdates.begin(), u);
677     }
678     
679     private:
680     Rewriter &rewrite;
681     FindCallArgVisitor fcaVisitor;
682     std::set<const NamedDecl *> &DeclsRead;
683     std::set<const NamedDecl *> &DeclsInCond;
684     std::map<const NamedDecl *, std::string> &DeclToMCVar;
685     std::map<const Expr *, std::string> &ExprToMCVar;
686     std::set<const Stmt *> &StmtsHandled;
687     std::vector<Update *> &DeferredUpdates;
688     std::map<const Expr *, SourceLocation> &Redirector;
689 };
690
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
692 public:
693     FindReturnsBreaksVisitor() : Returns(), Breaks() {}
694
695     bool VisitStmt(Stmt * s) {
696         if (isa<ReturnStmt>(s))
697             Returns.push_back(cast<ReturnStmt>(s));
698
699         if (isa<BreakStmt>(s))
700             Breaks.push_back(cast<BreakStmt>(s));
701         return true;
702     }
703
704     void Clear() {
705         Returns.clear(); Breaks.clear();
706     }
707
708     const std::vector<const ReturnStmt *> RetrieveReturns() {
709         return Returns;
710     }
711
712     const std::vector<const BreakStmt *> RetrieveBreaks() {
713         return Breaks;
714     }
715
716 private:
717     std::vector<const ReturnStmt *> Returns;
718     std::vector<const BreakStmt *> Breaks;
719 };
720
721 class LoopHandler : public MatchFinder::MatchCallback {
722 public:
723     LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
724
725     virtual void run(const MatchFinder::MatchResult &Result) {
726         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
727
728         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729                            "MC2_enterLoop();\n", true, true);
730
731         // annotate all returns with MC2_exitLoop()
732         // annotate all breaks that aren't further nested with MC2_exitLoop().
733         FindReturnsBreaksVisitor frbv;
734         if (isa<ForStmt>(s))
735             frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736         if (isa<WhileStmt>(s))
737             frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
738         if (isa<DoStmt>(s))
739             frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
740
741         for (auto & r : frbv.RetrieveReturns()) {
742             rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
743         }
744         
745         // need to find all breaks and returns embedded inside the loop
746
747         rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748                                      "\nMC2_exitLoop();\n");
749     }
750
751 private:
752     Rewriter &rewrite;
753 };
754
755 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
756 class AssignHandler : public MatchFinder::MatchCallback {
757 public:
758     AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759                   std::set<const NamedDecl *> &DeclsInCond,
760                   std::set<const VarDecl *> &DeclsNeedingMC,
761                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
762                   std::set<const Stmt *> &StmtsHandled,
763                   std::set<const Expr *> &MallocExprs,
764                   std::vector<Update *> &DeferredUpdates) :
765         rewrite(rewrite),
766         DeclsRead(DeclsRead),
767         DeclsInCond(DeclsInCond),
768         DeclsNeedingMC(DeclsNeedingMC),
769         DeclToMCVar(DeclToMCVar),
770         StmtsHandled(StmtsHandled),
771         MallocExprs(MallocExprs),
772         DeferredUpdates(DeferredUpdates) {}
773
774     virtual void run(const MatchFinder::MatchResult &Result) {
775         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777         FindLocalsVisitor locals, locals_rhs;
778
779         const VarDecl * lhs = NULL;
780         const Expr * rhs = NULL;
781         const DeclStmt * ds;
782
783         if (s && (ds = dyn_cast<DeclStmt>(s))) {
784             // long term goal: refactor the run() method to deal with one assignment at a time
785             // for now, if there is only declarations and no rhs's, we'll ignore this stmt
786             if (!ds->isSingleDecl()) {
787                 for (auto & d : ds->decls()) {
788                     VarDecl * vd = dyn_cast<VarDecl>(d);
789                     if (!d || vd->hasInit())
790                         assert(0 && "unsupported form of decl");
791                 }
792                 return;
793             }
794
795             lhs = retrieveSingleDecl(ds);
796         }
797
798         if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
799             return;
800
801         if (lhs) {
802             if (lhs->hasInit()) {
803                 rhs = lhs->getInit();
804                 if (rhs) {
805                     rhs = rhs->IgnoreCasts();
806                 }
807             }
808             else
809                 return;
810         }
811         std::set<std::string> mcState;
812
813         bool lhsUsedInCond;
814         bool rhsRead = false;
815
816         bool lhsTooComplicated = false;
817         if (op) {
818             DeclRefExpr * vd;
819             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
820                 lhs = dyn_cast<VarDecl>(vd->getDecl());
821             else {
822                 // kick the can along...
823                 lhsTooComplicated = true;
824             }
825
826             rhs = op->getRHS();
827             if (rhs) 
828                 rhs = rhs->IgnoreCasts();
829         }
830
831         // rhs must be MC-active state, i.e. in declsread
832         // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
833
834         if (rhs) {
835             locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
836             for (auto & nd : locals_rhs.RetrieveVars()) {
837                 if (DeclsRead.find(nd) != DeclsRead.end())
838                     rhsRead = true;
839             }
840         }
841
842         locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
843
844         lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
845         if (lhsUsedInCond) {
846             for (auto & d : locals.RetrieveVars()) {
847                 if (DeclToMCVar.count(d) > 0)
848                     mcState.insert(DeclToMCVar[d]);
849                 else if (DeclsRead.find(d) != DeclsRead.end())
850                     mcState.insert(encode(d->getName().str()));
851             }
852         }
853         if (rhsRead) {
854             for (auto & d : locals_rhs.RetrieveVars()) {
855                 if (DeclToMCVar.count(d) > 0)
856                     mcState.insert(DeclToMCVar[d]);
857                 else if (DeclsRead.find(d) != DeclsRead.end())
858                     mcState.insert(encode(d->getName().str()));
859             }
860         }
861         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
862             if (lhsTooComplicated)
863                 assert(0 && "couldn't find LHS of = operator");
864
865             std::stringstream nol;
866             std::string _lhsStr, lhsStr;
867             std::string mcVar = encodeFn(fnCount++);
868             if (lhs) {
869                 lhsStr = lhs->getName().str();
870                 _lhsStr = encode(lhsStr);
871                 DeclToMCVar[lhs] = mcVar;
872                 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
873             }
874             int function_id = 0;
875             if (!(MallocExprs.find(rhs) != MallocExprs.end()))
876                 function_id = ++funcCount;
877             nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
878             if (lhs)
879                 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
880             else 
881                 nol << ", MC2_PTR_LENGTH";
882             for (auto & d : mcState) {
883                 nol <<  ", ";
884                 if (_lhsStr == d)
885                     nol << mcVar;
886                 else
887                     nol << d;
888             }
889             nol << "); ";
890             SourceLocation place;
891             if (op) {
892                 place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
893             } else
894                 place = s->getLocEnd();
895             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
896                                nol.str(), true, true);
897
898             updateProvisionalName(DeferredUpdates, lhs, mcVar);
899         }
900     }
901
902     private:
903     Rewriter &rewrite;
904     std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
905     std::set<const VarDecl *> &DeclsNeedingMC;
906     std::map<const NamedDecl *, std::string> &DeclToMCVar;
907     std::set<const Stmt *> &StmtsHandled;
908     std::set<const Expr *> &MallocExprs;
909     std::vector<Update *> &DeferredUpdates;
910 };
911
912 // record vars used in conditions
913 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
914 public:
915     BranchConditionRefactoringHandler(Rewriter &rewrite,
916                                       std::set<const NamedDecl *> & DeclsInCond,
917                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
918                                       std::map<const Expr *, std::string> &ExprToMCVar,
919                                       std::map<const Expr *, SourceLocation> &Redirector,
920                                       std::vector<Update *> &DeferredUpdates) :
921         rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
922         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
923
924     virtual void run(const MatchFinder::MatchResult &Result) {
925         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
926         Expr * cond = is->getCond();
927
928         // refactor out complicated conditions
929         FindCallArgVisitor flv;
930         flv.TraverseStmt(cond);
931         std::string mcVar;
932
933         BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
934         if (bc) {
935             std::string condVar = encodeCond(condCount++);
936             std::stringstream condVarEncoded;
937             condVarEncoded << condVar << "_m";
938
939             // prettyprint the binary op
940             // e.g. int _cond0 = x == y;
941             std::string SStr;
942             llvm::raw_string_ostream S(SStr);
943             bc->printPretty(S, nullptr, rewrite.getLangOpts());
944             const std::string &Str = S.str();
945             
946             std::stringstream prel;
947
948             bool is_equality = false;
949             // handle equality tests
950             if (bc->getOpcode() == BO_EQ) {
951                 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
952                 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
953                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
954                     is_equality = true;
955                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
956                     std::string ld = DeclToMCVar.find(l->getDecl())->second,
957                         rd = DeclToMCVar.find(r->getDecl())->second;
958
959                     prel << "\nint " << condVar << " = MC2_equals(" <<
960                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
961                         rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
962                         "&" << condVarEncoded.str() << ");\n";
963                 }
964             }
965
966             if (!is_equality) {
967                 prel << "\nint " << condVar << " = " << Str << ";";
968                 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
969                 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
970                 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
971                     prel << ", " << DeclToMCVar[d->getDecl()];
972                 }
973                 prel << ");\n";
974             }
975
976             ExprToMCVar[cond] = condVarEncoded.str();
977             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
978                                prel.str(), false, true);
979
980             // rewrite the binary op with the newly-inserted var
981             Expr * RO = bc->getRHS(); // used for location only
982
983             int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
984             SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
985             rewrite.ReplaceText(SR, condVar);
986         } else {
987             std::string condVar = encodeCond(condCount++);
988             std::stringstream condVarEncoded;
989             condVarEncoded << condVar << "_m";
990
991             std::string SStr;
992             llvm::raw_string_ostream S(SStr);
993             cond->printPretty(S, nullptr, rewrite.getLangOpts());
994             const std::string &Str = S.str();
995
996             std::stringstream prel;
997             prel << "\nint " << condVar << " = " << Str << ";";
998             prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
999             std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
1000             const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
1001             if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
1002                 prel << ", " << DeclToMCVar[d->getDecl()];
1003             } else {
1004                 prel << ", ";
1005                 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
1006                 vars->push_back(v);
1007             }
1008             prel << ");\n";
1009
1010             ExprToMCVar[cond] = condVarEncoded.str();
1011             // gross hack; should look for any callexprs in cond
1012             // but right now, if it's a unaryop, just manually traverse
1013             if (isa<UnaryOperator>(cond)) {
1014                 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
1015                 ExprToMCVar[e] = condVarEncoded.str();
1016             }
1017             Update * u = new Update(is->getLocStart(), prel.str(), vars);
1018             DeferredUpdates.push_back(u);
1019
1020             // rewrite the call op with the newly-inserted var
1021             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1022             Redirector[cond] = is->getLocStart();
1023             rewrite.ReplaceText(SR, condVar);
1024         }
1025
1026         std::deque<const Decl *> q;
1027         const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1028         q.push_back(d);
1029         while (!q.empty()) {
1030             const Decl * d = q.back();
1031             q.pop_back();
1032             if (isa<NamedDecl>(d))
1033                 DeclsInCond.insert(cast<NamedDecl>(d));
1034
1035             const VarDecl * vd;
1036             if ((vd = dyn_cast<VarDecl>(d))) {
1037                 if (vd->hasInit()) {
1038                     const Expr * e = vd->getInit();
1039                     flv.Clear();
1040                     flv.TraverseStmt(const_cast<Expr *>(e));
1041                     const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1042                     q.push_back(d);
1043                 }
1044             }
1045         }
1046     }
1047
1048 private:
1049     Rewriter &rewrite;
1050     std::set<const NamedDecl *> & DeclsInCond;
1051     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1052     std::map<const Expr *, std::string> &ExprToMCVar;
1053     std::map<const Expr *, SourceLocation> &Redirector;
1054     std::vector<Update *> &DeferredUpdates;
1055 };
1056
1057 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1058 public:
1059     BranchAnnotationHandler(Rewriter &rewrite,
1060                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
1061                             std::map<const Expr *, std::string> & ExprToMCVar)
1062         : rewrite(rewrite),
1063           DeclToMCVar(DeclToMCVar),
1064           ExprToMCVar(ExprToMCVar){}
1065     virtual void run(const MatchFinder::MatchResult &Result) {
1066         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1067
1068         // if the branch condition is interesting:
1069         // (but right now, not too interesting)
1070         Expr * cond = is->getCond()->IgnoreCasts();
1071
1072         FindLocalsVisitor flv;
1073         flv.TraverseStmt(cond);
1074         if (flv.RetrieveVars().size() == 0) return;
1075
1076         const NamedDecl * condVar = flv.RetrieveVars()[0];
1077
1078         std::string mCondVar;
1079         if (ExprToMCVar.count(cond) > 0)
1080             mCondVar = ExprToMCVar[cond];
1081         else if (DeclToMCVar.count(condVar) > 0) 
1082             mCondVar = DeclToMCVar[condVar];
1083         else
1084             mCondVar = encode(condVar->getName());
1085         std::string brVar = encodeBranch(branchCount++);
1086
1087         std::stringstream brline;
1088         brline << "MCID " << brVar << ";\n";
1089         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1090                            brline.str(), false, true);
1091
1092         Stmt * ts = is->getThen(), * es = is->getElse();
1093         bool tHasChild = hasChild(ts);
1094         SourceLocation tfl;
1095         if (tHasChild) {
1096             if (isa<CompoundStmt>(ts))
1097                 tfl = getFirstChild(ts)->getLocStart();
1098             else
1099                 tfl = ts->getLocStart();
1100         } else
1101             tfl = ts->getLocStart().getLocWithOffset(1);
1102         SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1103
1104         std::stringstream tlineStart, mergeStmt, eline;
1105
1106         UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1107         tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1108         eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1109
1110         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1111
1112         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
1113
1114         Stmt * tls = NULL;
1115         int extra_else_offset = 0;
1116
1117         if (tHasChild) { tls = getLastChild(ts); }
1118         if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1119
1120         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1121             extra_else_offset = 0;
1122             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1123                                mergeStmt.str(), true, true);
1124         }
1125         if (tHasChild && !isa<CompoundStmt>(ts)) {
1126             rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1127             SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1128             rewrite.InsertText(tend, "}", true, true);
1129         }
1130         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1131
1132         if (es) {
1133             SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1134             bool eHasChild = hasChild(es); 
1135             Stmt * els = NULL;
1136             if (eHasChild) els = getLastChild(es); else els = es;
1137             
1138             eline << "\n";
1139
1140             SourceLocation el;
1141             if (eHasChild) {
1142                 if (isa<CompoundStmt>(es))
1143                     el = getFirstChild(es)->getLocStart();
1144                 else {
1145                     el = es->getLocStart();
1146                 }
1147             } else
1148                 el = es->getLocStart().getLocWithOffset(1);
1149             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
1150
1151             if (eHasChild && !isa<CompoundStmt>(es)) {
1152                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1153                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1154             }
1155
1156             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1157                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
1158         }
1159         else {
1160             std::stringstream eCompoundLine;
1161             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1162             SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1163             if (!tend.isValid())
1164                 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1165             rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
1166         }
1167     }
1168 private:
1169
1170     bool hasChild(Stmt * s) {
1171         if (!isa<CompoundStmt>(s)) return true;
1172         return (!cast<CompoundStmt>(s)->body_empty());
1173     }
1174
1175     Stmt * getFirstChild(Stmt * s) {
1176         assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1177         assert(!cast<CompoundStmt>(s)->body_empty());
1178         return *(cast<CompoundStmt>(s)->body_begin());
1179     }
1180
1181     Stmt * getLastChild(Stmt * s) {
1182         CompoundStmt * cs;
1183         if ((cs = dyn_cast<CompoundStmt>(s))) {
1184             assert (!cs->body_empty());
1185             return cs->body_back();
1186         }
1187         return s;
1188     }
1189
1190     Rewriter &rewrite;
1191     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1192     std::map<const Expr *, std::string> &ExprToMCVar;
1193 };
1194
1195 class FunctionCallHandler : public MatchFinder::MatchCallback {
1196 public:
1197     FunctionCallHandler(Rewriter &rewrite,
1198                         std::map<const NamedDecl *, std::string> &DeclToMCVar,
1199                         std::set<const FunctionDecl *> &ThreadMains)
1200         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1201
1202     virtual void run(const MatchFinder::MatchResult &Result) {
1203         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
1204         Decl * d = ce->getCalleeDecl();
1205         NamedDecl * nd = dyn_cast<NamedDecl>(d);
1206         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1207         ASTContext *Context = Result.Context;
1208
1209         if (nd->getName() == "thrd_create") {
1210             Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1211             UnaryOperator * callee1;
1212             if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1213                 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1214                     callee0 = callee1->getSubExpr();
1215             }
1216             DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1217             if (!callee) return;
1218             FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1219             ThreadMains.insert(fd);
1220             return;
1221         }
1222
1223         if (!d->hasBody())
1224             return;
1225
1226         if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1227             // TODO check that the type is mc-visible also?
1228             const DeclStmt * ds;
1229             const VarDecl * lhs = NULL;
1230             std::string mc_rv = encodeRV(rvCount++);
1231
1232             std::stringstream brline;
1233             brline << "MCID " << mc_rv << ";\n";
1234             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1235                                brline.str(), false, true);
1236
1237             std::stringstream nol;
1238             if (ce->getNumArgs() > 0) nol << ", ";
1239             nol << "&" << mc_rv;
1240             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
1241                                      nol.str());
1242
1243             if (s && (ds = dyn_cast<DeclStmt>(s))) {
1244                 if (!ds->isSingleDecl()) {
1245                     for (auto & d : ds->decls()) {
1246                         VarDecl * vd = dyn_cast<VarDecl>(d);
1247                         if (!d || vd->hasInit())
1248                             assert(0 && "unsupported form of decl");
1249                     }
1250                     return;
1251                 }
1252
1253                 lhs = retrieveSingleDecl(ds);
1254             }
1255
1256             DeclToMCVar[lhs] = mc_rv;
1257         }
1258
1259         for (const auto & a : ce->arguments()) {
1260             std::stringstream nol;
1261
1262             std::string aa = "MCID_NODEP";
1263
1264             Expr * e = a->IgnoreCasts();
1265             DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1266             if (dr) { 
1267                 NamedDecl * d = dr->getDecl();
1268                 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1269                     aa = DeclToMCVar[d];
1270             }
1271
1272             nol << aa << ", ";
1273             
1274             if (a->getLocEnd().isValid())
1275                 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
1276                                          nol.str());
1277         }
1278     }
1279
1280 private:
1281     Rewriter &rewrite;
1282     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1283     std::set<const FunctionDecl *> &ThreadMains;
1284 };
1285
1286 class ReturnHandler : public MatchFinder::MatchCallback {
1287 public:
1288     ReturnHandler(Rewriter &rewrite,
1289                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
1290                   std::set<const FunctionDecl *> &ThreadMains)
1291         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1292
1293     virtual void run(const MatchFinder::MatchResult &Result) {
1294         const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1295         ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1296         Expr * rv = const_cast<Expr *>(rs->getRetValue());
1297
1298         if (!rv) return;        
1299         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1300         // not sure why this is explicitly needed, but crashes without it
1301         if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1302
1303         FindLocalsVisitor flv;
1304         flv.TraverseStmt(rv);
1305         std::string mrv = "MCID_NODEP";
1306
1307         if (flv.RetrieveVars().size() > 0) {
1308             const NamedDecl * returnVar = flv.RetrieveVars()[0];
1309             if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1310                 mrv = DeclToMCVar[returnVar];
1311             }
1312         }
1313         std::stringstream nol;
1314         nol << "*retval = " << mrv << ";\n";
1315         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1316                            nol.str(), false, true);
1317     }
1318
1319 private:
1320     Rewriter &rewrite;
1321     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1322     std::set<const FunctionDecl *> &ThreadMains;
1323 };
1324
1325 class VarDeclHandler : public MatchFinder::MatchCallback {
1326 public:
1327     VarDeclHandler(Rewriter &rewrite,
1328                    std::map<const NamedDecl *, std::string> &DeclToMCVar,
1329                    std::set<const VarDecl *> &DeclsNeedingMC)
1330         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1331
1332     virtual void run(const MatchFinder::MatchResult &Result) {
1333         VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1334         std::stringstream nol;
1335
1336         if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1337
1338         std::string dn;
1339         if (DeclToMCVar.find(d) != DeclToMCVar.end())
1340             dn = DeclToMCVar[d];
1341         else
1342             dn = encode(d->getName().str());
1343
1344         nol << "MCID " << dn << "; ";
1345
1346         if (d->getLocStart().isValid())
1347             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
1348                                      nol.str());
1349     }
1350
1351 private:
1352     Rewriter &rewrite;
1353     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1354     std::set<const VarDecl *> &DeclsNeedingMC;
1355 };
1356
1357 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1358 public:
1359     FunctionDeclHandler(Rewriter &rewrite,
1360                         std::set<const FunctionDecl *> &ThreadMains)
1361         : rewrite(rewrite), ThreadMains(ThreadMains) {}
1362
1363     virtual void run(const MatchFinder::MatchResult &Result) {
1364         FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
1365
1366         if (!fd->getIdentifier()) return;
1367
1368         if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1369
1370         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1371
1372         SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1373         for (auto & p : fd->params()) {
1374             std::stringstream nol;
1375             nol << "MCID " << encode(p->getName()) << ", ";
1376             if (p->getLocStart().isValid())
1377                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
1378                                    nol.str(), false);
1379             if (p->getLocEnd().isValid())
1380                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1381         }
1382
1383         if (!fd->getReturnType()->isVoidType()) {
1384             std::stringstream nol;
1385             if (fd->param_size() > 0) nol << ", ";
1386             nol << "MCID * retval";
1387             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
1388                                nol.str(), false);
1389         }
1390     }
1391
1392 private:
1393     Rewriter &rewrite;
1394     std::set<const FunctionDecl *> &ThreadMains;
1395 };
1396
1397 class BailHandler : public MatchFinder::MatchCallback {
1398 public:
1399     BailHandler() {}
1400     virtual void run(const MatchFinder::MatchResult &Result) {
1401         assert(0 && "we don't handle goto statements");
1402     }
1403 };
1404
1405 class MyASTConsumer : public ASTConsumer {
1406 public:
1407     MyASTConsumer(Rewriter &R) : R(R),
1408                                  DeclsRead(),
1409                                  DeclsInCond(),
1410                                  DeclToMCVar(),
1411                                  HandlerMalloc(MallocExprs),
1412                                  HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1413                                  HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1414                                  HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1415                                  HandlerLoop(R),
1416                                  HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1417                                  HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1418                                  HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1419                                  HandlerFunctionDecl(R, ThreadMains),
1420                                  HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1421                                  HandlerReturn(R, DeclToMCVar, ThreadMains),
1422                                  HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
1423                                  HandlerBail() {
1424         MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1425                                                       hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1426                                                       hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1427                                        &HandlerFunctionCall);
1428         MatcherLoadStore.addMatcher
1429             (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1430              &HandlerMalloc);
1431
1432         MatcherLoadStore.addMatcher
1433             (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1434                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1435                             hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1436                             hasParent(stmt().bind("containingStmt"))))
1437              .bind("callExpr"),
1438              &HandlerLoad);
1439
1440         MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1441                                     &HandlerStore);
1442
1443         MatcherLoadStore.addMatcher
1444             (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1445                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1446                             hasAncestor(binaryOperator(hasOperatorName("="),
1447                                                        hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1448                             anything()))
1449              .bind("callExpr"),
1450              &HandlerRMW);
1451
1452         MatcherLoadStore.addMatcher(ifStmt(hasCondition
1453                                            (anyOf(binaryOperator().bind("bc"),
1454                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1455                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1456                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1457                                                   anything()))).bind("if"),
1458                                     &HandlerBranchConditionRefactoring);
1459
1460         MatcherLoadStore.addMatcher(forStmt().bind("s"),
1461                                     &HandlerLoop);
1462         MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1463                                     &HandlerLoop);
1464         MatcherLoadStore.addMatcher(doStmt().bind("s"),
1465                                     &HandlerLoop);
1466
1467         MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1468                                                         hasParent(compoundStmt())),
1469                                                         hasOperatorName("=")).bind("op"),
1470                                    &HandlerAssign);
1471         MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1472
1473         MatcherFunction.addMatcher(ifStmt().bind("if"),
1474                                    &HandlerAnnotateBranch);
1475
1476         MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1477                                        &HandlerFunctionDecl);
1478         MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1479         MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
1480                                    &HandlerReturn);
1481
1482         MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1483     }
1484
1485     // Override the method that gets called for each parsed top-level
1486     // declaration.
1487     void HandleTranslationUnit(ASTContext &Context) override {
1488         LangOpts = Context.getLangOpts();
1489
1490         MatcherFunctionCall.matchAST(Context);
1491         MatcherLoadStore.matchAST(Context);
1492         MatcherFunction.matchAST(Context);
1493         MatcherFunctionDecl.matchAST(Context);
1494         MatcherSanity.matchAST(Context);
1495
1496         for (auto & u : DeferredUpdates) {
1497             R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
1498             delete u;
1499         }
1500         DeferredUpdates.clear();
1501     }
1502
1503 private:
    /* DeclsRead contains all local variables 'x' which:
     * 1) appear in 'x = load_32(...);'
     * 2) appear in 'y = store_32(x);' */
1507     std::set<const NamedDecl *> DeclsRead;
1508     /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
1509     std::set<const NamedDecl *> DeclsInCond;
1510     std::map<const NamedDecl *, std::string> DeclToMCVar;
1511     std::map<const Expr *, std::string> ExprToMCVar;
1512     std::set<const VarDecl *> DeclsNeedingMC;
1513     std::set<const FunctionDecl *> ThreadMains;
1514     std::set<const Stmt *> StmtsHandled;
1515     std::set<const Expr *> MallocExprs;
1516     std::map<const Expr *, SourceLocation> Redirector;
1517     std::vector<Update *> DeferredUpdates;
1518
1519     Rewriter &R;
1520
1521     MallocHandler HandlerMalloc;
1522     LoadHandler HandlerLoad;
1523     StoreHandler HandlerStore;
1524     RMWHandler HandlerRMW;
1525     LoopHandler HandlerLoop;
1526     BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1527     BranchAnnotationHandler HandlerAnnotateBranch;
1528     AssignHandler HandlerAssign;
1529     FunctionDeclHandler HandlerFunctionDecl;
1530     FunctionCallHandler HandlerFunctionCall;
1531     ReturnHandler HandlerReturn;
1532     VarDeclHandler HandlerVarDecl;
1533     BailHandler HandlerBail;
1534     MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1535 };
1536
1537 // For each source file provided to the tool, a new FrontendAction is created.
1538 class MyFrontendAction : public ASTFrontendAction {
1539 public:
1540     MyFrontendAction() {}
1541     void EndSourceFileAction() override {
1542         SourceManager &SM = TheRewriter.getSourceMgr();
1543         llvm::errs() << "** EndSourceFileAction for: "
1544                      << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1545
1546         // Now emit the rewritten buffer.
1547         TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1548     }
1549
1550     std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1551                                                    StringRef file) override {
1552         llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1553         TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1554         return llvm::make_unique<MyASTConsumer>(TheRewriter);
1555     }
1556
1557 private:
1558     Rewriter TheRewriter;
1559 };
1560
1561 int main(int argc, const char **argv) {
1562     CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1563     ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1564     
1565     return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());
1566 }