Improve the determinism of MergeFunctions
authorJF Bastien <jfb@google.com>
Fri, 21 Aug 2015 23:27:24 +0000 (23:27 +0000)
committerJF Bastien <jfb@google.com>
Fri, 21 Aug 2015 23:27:24 +0000 (23:27 +0000)
Summary:

Merge functions previously relied on unsigned comparisons of pointer values to
order functions. This caused observable non-determinism in the compiler for
large bitcode programs. Basically, opt -mergefuncs program.bc | md5sum produces
different hashes when run repeatedly on the same machine. Differing output was
observed on three large bitcodes, but it was less frequent on the smallest file.
It is possible that this only manifests on the large inputs, hence remaining
undetected until now.

This patch fixes this by removing (almost, see below) all places where
comparisons between pointers are used to order functions. Most of these changes
are local, but the comparison of global values requires assigning an identifier
to each local in the order it is visited. This is very similar to the way the
comparison function identifies Value*'s defined within a function. Because the
order of visiting the functions and their subparts is deterministic, the
identifiers assigned to the globals will be as well, and the order of functions
will be deterministic.

With these changes, there is no more observed non-determinism. There is also
only minor slowdowns (negligible to 4%) compared to the baseline, which is
likely a result of the fact that global comparisons involve hash lookups and not
just pointer comparisons.

The one caveat so far is that programs containing BlockAddress constants can
still be non-deterministic. It is not clear what the right solution is here. In
particular, even if the global numbers are used to order by function, we still
need a way to order the BasicBlock*'s. Unfortunately, we cannot just bail out
and fail to order the functions or consider them equal, because we require a
total order over functions. Note that programs with BlockAddress constants are
relatively rare, so the impact of leaving this in is minor as long as this pass
is opt-in.

Author: jrkoenig

Reviewers: nlewycky, jfb, dschuff

Subscribers: jevinskie, llvm-commits, chapuni

Differential revision: http://reviews.llvm.org/D12168

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@245762 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/IPO/MergeFunctions.cpp
test/Transforms/MergeFunc/constant-entire-value.ll [new file with mode: 0644]

index 67d6b7fcb674b47e78d4df51f22ca0b0172eca7b..a31a08039796300ce075e17913c02e4c5e8b66b1 100644 (file)
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/ValueHandle.h"
+#include "llvm/IR/ValueMap.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -130,14 +131,50 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck(
 
 namespace {
 
+/// GlobalNumberState assigns an integer to each global value in the program,
+/// which is used by the comparison routine to order references to globals. This
+/// state must be preserved throughout the pass, because Functions and other
+/// globals need to maintain their relative order. Globals are assigned a number
+/// when they are first visited. This order is deterministic, and so the
+/// assigned numbers are as well. When two functions are merged, neither number
+/// is updated. If the symbols are weak, this would be incorrect. If they are
+/// strong, then one will be replaced at all references to the other, and so
+/// direct callsites will now see one or the other symbol, and no update is
+/// necessary. Note that if we were guaranteed unique names, we could just
+/// compare those, but this would not work for stripped bitcodes or for those
+/// few symbols without a name.
+class GlobalNumberState {
+  struct Config : ValueMapConfig<GlobalValue*> {
+    enum { FollowRAUW = false };
+  };
+  // Each GlobalValue is mapped to an identifier. The Config ensures when RAUW
+  // occurs, the mapping does not change. Tracking changes is unnecessary, and
+  // also problematic for weak symbols (which may be overwritten).
+  typedef ValueMap<GlobalValue *, uint64_t, Config> ValueNumberMap;
+  ValueNumberMap GlobalNumbers;
+  // The next unused serial number to assign to a global.
+  uint64_t NextNumber;
+  public:
+    GlobalNumberState() : GlobalNumbers(), NextNumber(0) {}
+    uint64_t getNumber(GlobalValue* Global) {
+      ValueNumberMap::iterator MapIter;
+      bool Inserted;
+      std::tie(MapIter, Inserted) = GlobalNumbers.insert({Global, NextNumber});
+      if (Inserted)
+        NextNumber++;
+      return MapIter->second;
+    }
+};
+
 /// FunctionComparator - Compares two functions to determine whether or not
 /// they will generate machine code with the same behaviour. DataLayout is
 /// used if available. The comparator always fails conservatively (erring on the
 /// side of claiming that two functions are different).
 class FunctionComparator {
 public:
-  FunctionComparator(const Function *F1, const Function *F2)
-      : FnL(F1), FnR(F2) {}
+  FunctionComparator(const Function *F1, const Function *F2,
+                     GlobalNumberState* GN)
+      : FnL(F1), FnR(F2), GlobalNumbers(GN) {}
 
   /// Test whether the two functions have equivalent behaviour.
   int compare();
@@ -148,7 +185,7 @@ public:
 
 private:
   /// Test whether two basic blocks have equivalent behaviour.
-  int compare(const BasicBlock *BBL, const BasicBlock *BBR);
+  int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR);
 
   /// Constants comparison.
   /// Its analog to lexicographical comparison between hypothetical numbers
@@ -254,6 +291,10 @@ private:
   /// If these properties are equal - compare their contents.
   int cmpConstants(const Constant *L, const Constant *R);
 
+  /// Compares two global values by number. Uses the GlobalNumbersState to
+  /// identify the same gobals across function calls.
+  int cmpGlobalValues(GlobalValue *L, GlobalValue *R);
+
   /// Assign or look up previously assigned numbers for the two values, and
   /// return whether the numbers are equal. Numbers are assigned in the order
   /// visited.
@@ -333,8 +374,9 @@ private:
   ///
   /// 1. If types are of different kind (different type IDs).
   ///    Return result of type IDs comparison, treating them as numbers.
-  /// 2. If types are vectors or integers, compare Type* values as numbers.
-  /// 3. Types has same ID, so check whether they belongs to the next group:
+  /// 2. If types are integers, check that they have the same width. If they
+  /// are vectors, check that they have the same count and subtype.
+  /// 3. Types have the same ID, so check whether they are one of:
   /// * Void
   /// * Float
   /// * Double
@@ -343,8 +385,7 @@ private:
   /// * PPC_FP128
   /// * Label
   /// * Metadata
-  /// If so - return 0, yes - we can treat these types as equal only because
-  /// their IDs are same.
+  /// We can treat these types as equal whenever their IDs are same.
   /// 4. If Left and Right are pointers, return result of address space
   /// comparison (numbers comparison). We can treat pointer types of same
   /// address space as equal.
@@ -359,7 +400,8 @@ private:
 
   int cmpAPInts(const APInt &L, const APInt &R) const;
   int cmpAPFloats(const APFloat &L, const APFloat &R) const;
-  int cmpStrings(StringRef L, StringRef R) const;
+  int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
+  int cmpMem(StringRef L, StringRef R) const;
   int cmpAttrs(const AttributeSet L, const AttributeSet R) const;
 
   // The two functions undergoing comparison.
@@ -399,33 +441,28 @@ private:
   /// could be operands from further BBs we didn't scan yet.
   /// So it's impossible to use dominance properties in general.
   DenseMap<const Value*, int> sn_mapL, sn_mapR;
+
+  // The global state we will use
+  GlobalNumberState* GlobalNumbers;
 };
 
 class FunctionNode {
   mutable AssertingVH<Function> F;
   FunctionComparator::FunctionHash Hash;
-
 public:
   // Note the hash is recalculated potentially multiple times, but it is cheap.
-  FunctionNode(Function *F) : F(F), Hash(FunctionComparator::functionHash(*F)){}
+  FunctionNode(Function *F)
+    : F(F), Hash(FunctionComparator::functionHash(*F))  {}
   Function *getFunc() const { return F; }
+  FunctionComparator::FunctionHash getHash() const { return Hash; }
 
   /// Replace the reference to the function F by the function G, assuming their
   /// implementations are equal.
   void replaceBy(Function *G) const {
-    assert(!(*this < FunctionNode(G)) && !(FunctionNode(G) < *this) &&
-           "The two functions must be equal");
-
     F = G;
   }
 
   void release() { F = 0; }
-  bool operator<(const FunctionNode &RHS) const {
-    // Order first by hashes, then full function comparison.
-    if (Hash != RHS.Hash)
-      return Hash < RHS.Hash;
-    return (FunctionComparator(F, RHS.getFunc()).compare()) == -1;
-  }
 };
 }
 
@@ -444,13 +481,17 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
 }
 
 int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
-  if (int Res = cmpNumbers((uint64_t)&L.getSemantics(),
-                           (uint64_t)&R.getSemantics()))
+  // TODO: This correctly handles all existing fltSemantics, because they all
+  // have different precisions. This isn't very robust, however, if new types
+  // with different exponent ranges are introduced.
+  const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
+  if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
+                           APFloat::semanticsPrecision(SR)))
     return Res;
   return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
 }
 
-int FunctionComparator::cmpStrings(StringRef L, StringRef R) const {
+int FunctionComparator::cmpMem(StringRef L, StringRef R) const {
   // Prevent heavy comparison, compare sizes first.
   if (int Res = cmpNumbers(L.size(), R.size()))
     return Res;
@@ -556,9 +597,25 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
   if (!L->isNullValue() && R->isNullValue())
     return -1;
 
+  auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L));
+  auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R));
+  if (GlobalValueL && GlobalValueR) {
+    return cmpGlobalValues(GlobalValueL, GlobalValueR);
+  }
+
   if (int Res = cmpNumbers(L->getValueID(), R->getValueID()))
     return Res;
 
+  if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
+    const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
+    // This handles ConstantDataArray and ConstantDataVector. Note that we
+    // compare the two raw data arrays, which might differ depending on the host
+    // endianness. This isn't a problem though, because the endiness of a module
+    // will affect the order of the constants, but this order is the same
+    // for a given input module and host platform.
+    return cmpMem(SeqL->getRawDataValues(), SeqR->getRawDataValues());
+  }
+
   switch (L->getValueID()) {
   case Value::UndefValueVal: return TypesRes;
   case Value::ConstantIntVal: {
@@ -627,12 +684,21 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     }
     return 0;
   }
-  case Value::FunctionVal:
-  case Value::GlobalVariableVal:
-  case Value::GlobalAliasVal:
-  default: // Unknown constant, cast L and R pointers to numbers and compare.
+  case Value::BlockAddressVal: {
+    // FIXME: This still uses a pointer comparison. It isn't clear how to remove
+    // this. This only affects programs which take BlockAddresses and store them
+    // as constants, which is limited to interepreters, etc.
     return cmpNumbers((uint64_t)L, (uint64_t)R);
   }
+  default: // Unknown constant, abort.
+    DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
+    llvm_unreachable("Constant ValueID not recognized.");
+    return -1;
+  }
+}
+
+int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue* R) {
+  return cmpNumbers(GlobalNumbers->getNumber(L), GlobalNumbers->getNumber(R));
 }
 
 /// cmpType - compares two types,
@@ -660,10 +726,15 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
     llvm_unreachable("Unknown type!");
     // Fall through in Release mode.
   case Type::IntegerTyID:
-  case Type::VectorTyID:
-    // TyL == TyR would have returned true earlier.
-    return cmpNumbers((uint64_t)TyL, (uint64_t)TyR);
-
+    return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(),
+                      cast<IntegerType>(TyR)->getBitWidth());
+  case Type::VectorTyID: {
+    VectorType *VTyL = cast<VectorType>(TyL), *VTyR = cast<VectorType>(TyR);
+    if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements()))
+      return Res;
+    return cmpTypes(VTyL->getElementType(), VTyR->getElementType());
+  }
+  // TyL == TyR would have returned true earlier, because types are uniqued.
   case Type::VoidTyID:
   case Type::FloatTyID:
   case Type::DoubleTyID:
@@ -895,9 +966,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
   if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
       GEPR->accumulateConstantOffset(DL, OffsetR))
     return cmpAPInts(OffsetL, OffsetR);
-
-  if (int Res = cmpNumbers((uint64_t)GEPL->getPointerOperand()->getType(),
-                           (uint64_t)GEPR->getPointerOperand()->getType()))
+  if (int Res = cmpTypes(GEPL->getPointerOperand()->getType(),
+                         GEPR->getPointerOperand()->getType()))
     return Res;
 
   if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands()))
@@ -911,6 +981,28 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
   return 0;
 }
 
+int FunctionComparator::cmpInlineAsm(const InlineAsm *L,
+                                     const InlineAsm *R) const {
+  // InlineAsm's are uniqued. If they are the same pointer, obviously they are
+  // the same, otherwise compare the fields.
+  if (L == R)
+    return 0;
+  if (int Res = cmpTypes(L->getFunctionType(), R->getFunctionType()))
+    return Res;
+  if (int Res = cmpMem(L->getAsmString(), R->getAsmString()))
+    return Res;
+  if (int Res = cmpMem(L->getConstraintString(), R->getConstraintString()))
+    return Res;
+  if (int Res = cmpNumbers(L->hasSideEffects(), R->hasSideEffects()))
+    return Res;
+  if (int Res = cmpNumbers(L->isAlignStack(), R->isAlignStack()))
+    return Res;
+  if (int Res = cmpNumbers(L->getDialect(), R->getDialect()))
+    return Res;
+  llvm_unreachable("InlineAsm blocks were not uniqued.");
+  return 0;
+}
+
 /// Compare two values used by the two functions under pair-wise comparison. If
 /// this is the first time the values are seen, they're added to the mapping so
 /// that we will detect mismatches on next use.
@@ -945,7 +1037,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) {
   const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R);
 
   if (InlineAsmL && InlineAsmR)
-    return cmpNumbers((uint64_t)L, (uint64_t)R);
+    return cmpInlineAsm(InlineAsmL, InlineAsmR);
   if (InlineAsmL)
     return 1;
   if (InlineAsmR)
@@ -957,7 +1049,8 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) {
   return cmpNumbers(LeftSN.first->second, RightSN.first->second);
 }
 // Test whether two basic blocks have equivalent behaviour.
-int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) {
+int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
+                                       const BasicBlock *BBR) {
   BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end();
   BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end();
 
@@ -1020,7 +1113,7 @@ int FunctionComparator::compare() {
     return Res;
 
   if (FnL->hasGC()) {
-    if (int Res = cmpNumbers((uint64_t)FnL->getGC(), (uint64_t)FnR->getGC()))
+    if (int Res = cmpMem(FnL->getGC(), FnR->getGC()))
       return Res;
   }
 
@@ -1028,7 +1121,7 @@ int FunctionComparator::compare() {
     return Res;
 
   if (FnL->hasSection()) {
-    if (int Res = cmpStrings(FnL->getSection(), FnR->getSection()))
+    if (int Res = cmpMem(FnL->getSection(), FnR->getSection()))
       return Res;
   }
 
@@ -1074,7 +1167,7 @@ int FunctionComparator::compare() {
     if (int Res = cmpValues(BBL, BBR))
       return Res;
 
-    if (int Res = compare(BBL, BBR))
+    if (int Res = cmpBasicBlocks(BBL, BBR))
       return Res;
 
     const TerminatorInst *TermL = BBL->getTerminator();
@@ -1129,7 +1222,7 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) {
   SmallVector<const BasicBlock *, 8> BBs;
   SmallSet<const BasicBlock *, 16> VisitedBBs;
 
-  // Walk the blocks in the same order as FunctionComparator::compare(),
+  // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(),
   // accumulating the hash of the function "structure." (BB and opcode sequence)
   BBs.push_back(&F.getEntryBlock());
   VisitedBBs.insert(BBs[0]);
@@ -1163,14 +1256,31 @@ class MergeFunctions : public ModulePass {
 public:
   static char ID;
   MergeFunctions()
-    : ModulePass(ID), HasGlobalAliases(false) {
+    : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)),
+      HasGlobalAliases(false) {
     initializeMergeFunctionsPass(*PassRegistry::getPassRegistry());
   }
 
   bool runOnModule(Module &M) override;
 
 private:
-  typedef std::set<FunctionNode> FnTreeType;
+  // The function comparison operator is provided here so that FunctionNodes do
+  // not need to become larger with another pointer.
+  class FunctionNodeCmp {
+    GlobalNumberState* GlobalNumbers;
+  public:
+    FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {}
+    bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const {
+      // Order first by hashes, then full function comparison.
+      if (LHS.getHash() != RHS.getHash())
+        return LHS.getHash() < RHS.getHash();
+      FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers);
+      return FCmp.compare() == -1;
+    }
+  };
+  typedef std::set<FunctionNode, FunctionNodeCmp> FnTreeType;
+
+  GlobalNumberState GlobalNumbers;
 
   /// A work queue of functions that may have been modified and should be
   /// analyzed again.
@@ -1245,8 +1355,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
       for (std::vector<WeakVH>::iterator J = I; J != E && j < Max; ++J, ++j) {
         Function *F1 = cast<Function>(*I);
         Function *F2 = cast<Function>(*J);
-        int Res1 = FunctionComparator(F1, F2).compare();
-        int Res2 = FunctionComparator(F2, F1).compare();
+        int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare();
+        int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare();
 
         // If F1 <= F2, then F2 >= F1, otherwise report failure.
         if (Res1 != -Res2) {
@@ -1267,8 +1377,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
             continue;
 
           Function *F3 = cast<Function>(*K);
-          int Res3 = FunctionComparator(F1, F3).compare();
-          int Res4 = FunctionComparator(F2, F3).compare();
+          int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare();
+          int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare();
 
           bool Transitive = true;
 
@@ -1556,6 +1666,8 @@ void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
           (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
          "Only change functions if both are strong or both are weak");
   (void)F;
+  assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
+         "The two functions must be equal");
 
   IterToF->replaceBy(G);
 }
diff --git a/test/Transforms/MergeFunc/constant-entire-value.ll b/test/Transforms/MergeFunc/constant-entire-value.ll
new file mode 100644 (file)
index 0000000..cb193d0
--- /dev/null
@@ -0,0 +1,42 @@
+; RUN: opt -S -mergefunc < %s | FileCheck %s
+
+; RUN: opt -S -mergefunc < %s | FileCheck -check-prefix=NOPLUS %s
+
+; This makes sure that zeros in constants don't cause problems with string based
+; memory comparisons
+define internal i32 @sum(i32 %x, i32 %y) {
+; CHECK-LABEL: @sum
+  %sum = add i32 %x, %y
+  %1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 2 ], 2
+  %sum2 = add i32 %sum, %1
+  %sum3 = add i32 %sum2, %y
+  ret i32 %sum3
+}
+
+define internal i32 @add(i32 %x, i32 %y) {
+; CHECK-LABEL: @add
+  %sum = add i32 %x, %y
+  %1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 1 ], 2
+  %sum2 = add i32 %sum, %1
+  %sum3 = add i32 %sum2, %y
+  ret i32 %sum3
+}
+
+define internal i32 @plus(i32 %x, i32 %y) {
+; NOPLUS-NOT: @plus
+  %sum = add i32 %x, %y
+  %1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 5 ], 2
+  %sum2 = add i32 %sum, %1
+  %sum3 = add i32 %sum2, %y
+  ret i32 %sum3
+}
+
+define internal i32 @next(i32 %x, i32 %y) {
+; CHECK-LABEL: @next
+  %sum = add i32 %x, %y
+  %1 = extractvalue [3 x i32] [ i32 3, i32 0, i32 5 ], 2
+  %sum2 = add i32 %sum, %1
+  %sum3 = add i32 %sum2, %y
+  ret i32 %sum3
+}
+