virtual void run(const MatchFinder::MatchResult &Result) {
BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
- FindLocalsVisitor flv;
+ FindLocalsVisitor locals, locals_rhs;
const VarDecl * lhs = NULL;
const Expr * rhs = NULL;
}
std::set<std::string> mcState;
+ bool lhsUsedInCond;
+ bool rhsRead = false;
+
bool lhsTooComplicated = false;
if (op) {
- flv.TraverseStmt(op);
-
DeclRefExpr * vd;
if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
lhs = dyn_cast<VarDecl>(vd->getDecl());
if (rhs)
rhs = rhs->IgnoreCasts();
}
- else if (lhs) {
- // rhs must be MC-active state, i.e. in declsread
- // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
- flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+ // rhs must be MC-active state, i.e. in declsread
+ // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
+
+ if (rhs) {
+ locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+ for (auto & nd : locals_rhs.RetrieveVars()) {
+ if (DeclsRead.find(nd) != DeclsRead.end())
+ rhsRead = true;
+ }
}
- if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
- for (auto & d : flv.RetrieveVars()) {
+ locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+ lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
+ if (lhsUsedInCond) {
+ for (auto & d : locals.RetrieveVars()) {
+ if (DeclToMCVar.count(d) > 0)
+ mcState.insert(DeclToMCVar[d]);
+ else if (DeclsRead.find(d) != DeclsRead.end())
+ mcState.insert(encode(d->getName().str()));
+ }
+ }
+ if (rhsRead) {
+ for (auto & d : locals_rhs.RetrieveVars()) {
if (DeclToMCVar.count(d) > 0)
mcState.insert(DeclToMCVar[d]);
else if (DeclsRead.find(d) != DeclsRead.end())
mcState.insert(encode(d->getName().str()));
}
}
-
if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
if (lhsTooComplicated)
assert(0 && "couldn't find LHS of = operator");
}
nol << "); ";
SourceLocation place;
- if (op)
- place = op->getLocEnd().getLocWithOffset(1);
- else
+ if (op) {
+ place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
+ } else
place = s->getLocEnd();
rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
nol.str(), true, true);
DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
is_equality = true;
prel << "\nMCID " << condVarEncoded.str() << ";\n";
- std::string ld = DeclToMCVar.find(l->getDecl())->second,
+ std::string ld, rd;
+ if (DeclToMCVar.find(l->getDecl()) != DeclToMCVar.end())
+ ld = DeclToMCVar.find(l->getDecl())->second;
+ else
+ ld = encode(l->getDecl()->getName());
+ if (DeclToMCVar.find(r->getDecl()) != DeclToMCVar.end())
rd = DeclToMCVar.find(r->getDecl())->second;
+ else
+ rd = encode(r->getDecl()->getName());
prel << "\nint " << condVar << " = MC2_equals(" <<
ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
/* DeclsRead contains all local variables 'x' which:
* 1) appear in 'x = load_32(...);
* 2) appear in 'y = store_32(x); */
- std::set<const NamedDecl *> DeclsRead, DeclsInCond;
+ std::set<const NamedDecl *> DeclsRead;
+ /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
+ std::set<const NamedDecl *> DeclsInCond;
std::map<const NamedDecl *, std::string> DeclToMCVar;
std::map<const Expr *, std::string> ExprToMCVar;
std::set<const VarDecl *> DeclsNeedingMC;