[LibCallSimplifier] propagate FMF when shrinking unary calls
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index e7eb39d6312fa595b9724412cdfd3434116761fc..2eba5fe31e0901573aa9def1e5db3e1eeee77bab 100644 (file)
@@ -57,8 +57,7 @@ static bool ignoreCallingConv(LibFunc::Func Func) {
          Func == LibFunc::llabs || Func == LibFunc::strlen;
 }
 
-/// isOnlyUsedInZeroEqualityComparison - Return true if it only matters that the
-/// value is equal or not-equal to zero.
+/// Return true if it only matters that the value is equal or not-equal to zero.
 static bool isOnlyUsedInZeroEqualityComparison(Value *V) {
   for (User *U : V->users()) {
     if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
@@ -72,8 +71,7 @@ static bool isOnlyUsedInZeroEqualityComparison(Value *V) {
   return true;
 }
 
-/// isOnlyUsedInEqualityComparison - Return true if it is only used in equality
-/// comparisons with With.
+/// Return true if it is only used in equality comparisons with With.
 static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) {
   for (User *U : V->users()) {
     if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
@@ -249,12 +247,12 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) {
       !FT->getParamType(2)->isIntegerTy())
     return nullptr;
 
-  // Extract some information from the instruction
+  // Extract some information from the instruction.
   Value *Dst = CI->getArgOperand(0);
   Value *Src = CI->getArgOperand(1);
   uint64_t Len;
 
-  // We don't do anything if length is not constant
+  // We don't do anything if length is not constant.
   if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
     Len = LengthArg->getZExtValue();
   else
@@ -272,12 +270,12 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilder<> &B) {
   if (SrcLen == 0 || Len == 0)
     return Dst;
 
-  // We don't optimize this case
+  // We don't optimize this case.
   if (Len < SrcLen)
     return nullptr;
 
   // strncat(x, s, c) -> strcat(x, s)
-  // s is constant so the strcat can be optimized further
+  // s is constant so the strcat can be optimized further.
   return emitStrLenMemCpy(Src, Dst, SrcLen, B);
 }
 
@@ -310,7 +308,8 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilder<> &B) {
   StringRef Str;
   if (!getConstantStringInfo(SrcStr, Str)) {
     if (CharC->isZero()) // strchr(p, 0) -> p + strlen(p)
-      return B.CreateGEP(B.getInt8Ty(), SrcStr, EmitStrLen(SrcStr, B, DL, TLI), "strchr");
+      return B.CreateGEP(B.getInt8Ty(), SrcStr, EmitStrLen(SrcStr, B, DL, TLI),
+                         "strchr");
     return nullptr;
   }
 
@@ -490,8 +489,8 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) {
 
   Type *PT = Callee->getFunctionType()->getParamType(0);
   Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
-  Value *DstEnd =
-      B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
+  Value *DstEnd = B.CreateGEP(B.getInt8Ty(), Dst,
+                              ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
 
   // We have enough information to now generate the memcpy call to do the
   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
@@ -599,7 +598,8 @@ Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) {
     if (I == StringRef::npos) // No match.
       return Constant::getNullValue(CI->getType());
 
-    return B.CreateGEP(B.getInt8Ty(), CI->getArgOperand(0), B.getInt64(I), "strpbrk");
+    return B.CreateGEP(B.getInt8Ty(), CI->getArgOperand(0), B.getInt64(I),
+                       "strpbrk");
   }
 
   // strpbrk(s, "a") -> strchr(s, 'a')
@@ -878,8 +878,10 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) {
       Type *RHSPtrTy =
           IntType->getPointerTo(RHS->getType()->getPointerAddressSpace());
 
-      Value *LHSV = B.CreateLoad(B.CreateBitCast(LHS, LHSPtrTy, "lhsc"), "lhsv");
-      Value *RHSV = B.CreateLoad(B.CreateBitCast(RHS, RHSPtrTy, "rhsc"), "rhsv");
+      Value *LHSV =
+          B.CreateLoad(B.CreateBitCast(LHS, LHSPtrTy, "lhsc"), "lhsv");
+      Value *RHSV =
+          B.CreateLoad(B.CreateBitCast(RHS, RHSPtrTy, "rhsc"), "rhsv");
 
       return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp");
     }
@@ -992,10 +994,14 @@ Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
   Value *V = valueHasFloatPrecision(CI->getArgOperand(0));
   if (V == nullptr)
     return nullptr;
+  
+  // Propagate fast-math flags from the existing call to the new call.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.SetFastMathFlags(CI->getFastMathFlags());
 
   // floor((double)floatval) -> (double)floorf(floatval)
   if (Callee->isIntrinsic()) {
-    Module *M = CI->getParent()->getParent()->getParent();
+    Module *M = CI->getModule();
     Intrinsic::ID IID = Callee->getIntrinsicID();
     Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy());
     V = B.CreateCall(F, V);
@@ -1058,6 +1064,31 @@ Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) {
   return Ret;
 }
 
+static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) {
+  // Multiplications calculated using Addition Chains.
+  // Refer: http://wwwhomes.uni-bielefeld.de/achim/addition_chain.html
+
+  assert(Exp != 0 && "Incorrect exponent 0 not handled");
+
+  if (InnerChain[Exp])
+    return InnerChain[Exp];
+
+  static const unsigned AddChain[33][2] = {
+      {0, 0}, // Unused.
+      {0, 0}, // Unused (base case = pow1).
+      {1, 1}, // Unused (pre-computed).
+      {1, 2},  {2, 2},   {2, 3},  {3, 3},   {2, 5},  {4, 4},
+      {1, 8},  {5, 5},   {1, 10}, {6, 6},   {4, 9},  {7, 7},
+      {3, 12}, {8, 8},   {8, 9},  {2, 16},  {1, 18}, {10, 10},
+      {6, 15}, {11, 11}, {3, 20}, {12, 12}, {8, 17}, {13, 13},
+      {3, 24}, {14, 14}, {4, 25}, {15, 15}, {3, 28}, {16, 16},
+  };
+
+  InnerChain[Exp] = B.CreateFMul(getPow(InnerChain, AddChain[Exp][0], B),
+                                 getPow(InnerChain, AddChain[Exp][1], B));
+  return InnerChain[Exp];
+}
+
 Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
@@ -1092,7 +1123,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                                   Callee->getAttributes());
   }
 
-  bool unsafeFPMath = canUseUnsafeFPMath(CI->getParent()->getParent());
+  bool UnsafeFPMath = canUseUnsafeFPMath(CI->getParent()->getParent());
 
   // pow(exp(x), y) -> exp(x*y)
   // pow(exp2(x), y) -> exp2(x * y)
@@ -1101,7 +1132,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   // underflow behavior quite dramatically.
   // Example: x = 1000, y = 0.001.
   // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
-  if (unsafeFPMath) {
+  if (UnsafeFPMath) {
     if (auto *OpC = dyn_cast<CallInst>(Op1)) {
       IRBuilder<>::FastMathFlagGuard Guard(B);
       FastMathFlags FMF;
@@ -1132,7 +1163,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                       LibFunc::fabsl)) {
 
     // In -ffast-math, pow(x, 0.5) -> sqrt(x).
-    if (unsafeFPMath)
+    if (UnsafeFPMath)
       return EmitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B,
                                   Callee->getAttributes());
 
@@ -1156,6 +1187,32 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
     return B.CreateFMul(Op1, Op1, "pow2");
   if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x
     return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip");
+
+  // In -ffast-math, generate repeated fmul instead of generating pow(x, n).
+  if (UnsafeFPMath) {
+    APFloat V = abs(Op2C->getValueAPF());
+    // We limit to a max of 7 fmul(s). Thus max exponent is 32.
+    // This transformation applies to integer exponents only.
+    if (V.compare(APFloat(V.getSemantics(), 32.0)) == APFloat::cmpGreaterThan ||
+        !V.isInteger())
+      return nullptr;
+
+    // We will memoize intermediate products of the Addition Chain.
+    Value *InnerChain[33] = {nullptr};
+    InnerChain[1] = Op1;
+    InnerChain[2] = B.CreateFMul(Op1, Op1);
+
+    // We cannot readily convert a non-double type (like float) to a double.
+    // So we first convert V to something which could be converted to double.
+    bool ignored;
+    V.convert(APFloat::IEEEdouble, APFloat::rmTowardZero, &ignored);
+    Value *FMul = getPow(InnerChain, V.convertToDouble(), B);
+    // For negative exponents simply compute the reciprocal.
+    if (Op2C->isNegative())
+      FMul = B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), FMul);
+    return FMul;
+  }
+
   return nullptr;
 }
 
@@ -1317,12 +1374,20 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
 
   LibFunc::Func Func;
   Function *F = OpC->getCalledFunction();
-  StringRef FuncName = F->getName();
-  if ((TLI->getLibFunc(FuncName, Func) && TLI->has(Func) &&
-      Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)
+  if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) &&
+      Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow))
     return B.CreateFMul(OpC->getArgOperand(1),
       EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
                            Callee->getAttributes()), "mul");
+
+  // log(exp2(y)) -> y*log(2)
+  if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) &&
+      TLI->has(Func) && Func == LibFunc::exp2)
+    return B.CreateFMul(
+        OpC->getArgOperand(0),
+        EmitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0),
+                             Callee->getName(), B, Callee->getAttributes()),
+        "logmul");
   return Ret;
 }
 
@@ -2302,7 +2367,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
 // log, logf, logl:
 //   * log(exp(x))   -> x
 //   * log(exp(y))   -> y*log(e)
-//   * log(exp2(y))  -> y*log(2)
 //   * log(exp10(y)) -> y*log(10)
 //   * log(sqrt(x))  -> 0.5*log(x)
 //
@@ -2361,7 +2425,8 @@ bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI,
   return false;
 }
 
-Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) {
+Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI,
+                                                     IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
 
   if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memcpy_chk))
@@ -2375,7 +2440,8 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &
   return nullptr;
 }
 
-Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) {
+Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI,
+                                                      IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
 
   if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memmove_chk))
@@ -2389,7 +2455,8 @@ Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<>
   return nullptr;
 }
 
-Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, IRBuilder<> &B) {
+Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,
+                                                     IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
 
   if (!checkStringCopyLibFuncSignature(Callee, LibFunc::memset_chk))