[Support] Add saturating multiply-add support function
authorNathan Slingerland <slingn@gmail.com>
Tue, 12 Jan 2016 22:34:00 +0000 (22:34 +0000)
committerNathan Slingerland <slingn@gmail.com>
Tue, 12 Jan 2016 22:34:00 +0000 (22:34 +0000)
Summary: Add SaturatingMultiplyAdd convenience function template since A + (X * Y) comes up frequently when doing weighted arithmetic.

Reviewers: davidxl, silvas

Subscribers: llvm-commits

Differential Revision: http://reviews.llvm.org/D15385

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@257532 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ProfileData/SampleProf.h
include/llvm/Support/MathExtras.h
lib/ProfileData/InstrProf.cpp
unittests/Support/MathExtrasTest.cpp

index 8df3fe80320930d38464e72cf4ed08e0fb6dc1d6..6c39cf9458dcb4f39e86303cef6b03b40f5fc4be 100644 (file)
@@ -140,16 +140,9 @@ public:
   /// around unsigned integers.
   sampleprof_error addSamples(uint64_t S, uint64_t Weight = 1) {
     bool Overflowed;
-    if (Weight > 1) {
-      S = SaturatingMultiply(S, Weight, &Overflowed);
-      if (Overflowed)
-        return sampleprof_error::counter_overflow;
-    }
-    NumSamples = SaturatingAdd(NumSamples, S, &Overflowed);
-    if (Overflowed)
-      return sampleprof_error::counter_overflow;
-
-    return sampleprof_error::success;
+    NumSamples = SaturatingMultiplyAdd(S, Weight, NumSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
   }
 
   /// Add called function \p F with samples \p S.
@@ -161,16 +154,10 @@ public:
                                    uint64_t Weight = 1) {
     uint64_t &TargetSamples = CallTargets[F];
     bool Overflowed;
-    if (Weight > 1) {
-      S = SaturatingMultiply(S, Weight, &Overflowed);
-      if (Overflowed)
-        return sampleprof_error::counter_overflow;
-    }
-    TargetSamples = SaturatingAdd(TargetSamples, S, &Overflowed);
-    if (Overflowed)
-      return sampleprof_error::counter_overflow;
-
-    return sampleprof_error::success;
+    TargetSamples =
+        SaturatingMultiplyAdd(S, Weight, TargetSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
   }
 
   /// Return true if this sample record contains function calls.
@@ -215,29 +202,17 @@ public:
   void dump() const;
   sampleprof_error addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
     bool Overflowed;
-    if (Weight > 1) {
-      Num = SaturatingMultiply(Num, Weight, &Overflowed);
-      if (Overflowed)
-        return sampleprof_error::counter_overflow;
-    }
-    TotalSamples = SaturatingAdd(TotalSamples, Num, &Overflowed);
-    if (Overflowed)
-      return sampleprof_error::counter_overflow;
-
-    return sampleprof_error::success;
+    TotalSamples =
+        SaturatingMultiplyAdd(Num, Weight, TotalSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
   }
   sampleprof_error addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
     bool Overflowed;
-    if (Weight > 1) {
-      Num = SaturatingMultiply(Num, Weight, &Overflowed);
-      if (Overflowed)
-        return sampleprof_error::counter_overflow;
-    }
-    TotalHeadSamples = SaturatingAdd(TotalHeadSamples, Num, &Overflowed);
-    if (Overflowed)
-      return sampleprof_error::counter_overflow;
-
-    return sampleprof_error::success;
+    TotalHeadSamples =
+        SaturatingMultiplyAdd(Num, Weight, TotalHeadSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
   }
   sampleprof_error addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
                                   uint64_t Num, uint64_t Weight = 1) {
index 8111aeebe6ee228b989c45de86a3437361e6a1c3..408ae3c339a22e794157c2940297bd4e53badb27 100644 (file)
@@ -717,6 +717,25 @@ SaturatingMultiply(T X, T Y, bool *ResultOverflowed = nullptr) {
   return Z;
 }
 
+/// \brief Multiply two unsigned integers, X and Y, and add the unsigned
+/// integer, A to the product. Clamp the result to the maximum representable
+/// value of T on overflow. ResultOverflowed indicates if the result is larger
+/// than the maximum representable value of type T.
+/// Note that this is purely a convenience function as there is no distinction
+/// where overflow occurred in a 'fused' multiply-add for unsigned numbers.
+template <typename T>
+typename std::enable_if<std::is_unsigned<T>::value, T>::type
+SaturatingMultiplyAdd(T X, T Y, T A, bool *ResultOverflowed = nullptr) {
+  bool Dummy;
+  bool &Overflowed = ResultOverflowed ? *ResultOverflowed : Dummy;
+
+  T Product = SaturatingMultiply(X, Y, &Overflowed);
+  if (Overflowed)
+    return Product;
+
+  return SaturatingAdd(A, Product, &Overflowed);
+}
+
 extern const float huge_valf;
 } // End llvm namespace
 
index 94c701de093c0031ad1832e3f1968d3f2c7af424..d6777639abe7ab2adf873d97fb6976adc027e75f 100644 (file)
@@ -269,14 +269,8 @@ instrprof_error InstrProfValueSiteRecord::merge(InstrProfValueSiteRecord &Input,
     while (I != IE && I->Value < J->Value)
       ++I;
     if (I != IE && I->Value == J->Value) {
-      uint64_t JCount = J->Count;
       bool Overflowed;
-      if (Weight > 1) {
-        JCount = SaturatingMultiply(JCount, Weight, &Overflowed);
-        if (Overflowed)
-          Result = instrprof_error::counter_overflow;
-      }
-      I->Count = SaturatingAdd(I->Count, JCount, &Overflowed);
+      I->Count = SaturatingMultiplyAdd(J->Count, Weight, I->Count, &Overflowed);
       if (Overflowed)
         Result = instrprof_error::counter_overflow;
       ++I;
@@ -328,13 +322,8 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other,
 
   for (size_t I = 0, E = Other.Counts.size(); I < E; ++I) {
     bool Overflowed;
-    uint64_t OtherCount = Other.Counts[I];
-    if (Weight > 1) {
-      OtherCount = SaturatingMultiply(OtherCount, Weight, &Overflowed);
-      if (Overflowed)
-        Result = instrprof_error::counter_overflow;
-    }
-    Counts[I] = SaturatingAdd(Counts[I], OtherCount, &Overflowed);
+    Counts[I] =
+        SaturatingMultiplyAdd(Other.Counts[I], Weight, Counts[I], &Overflowed);
     if (Overflowed)
       Result = instrprof_error::counter_overflow;
   }
index 945d8322b259e07e43efe0cb4fc902e2ca684466..97309f8d31f5e8d335e7cfd1b9e5a232a547c59d 100644 (file)
@@ -304,4 +304,58 @@ TEST(MathExtras, SaturatingMultiply) {
   SaturatingMultiplyTestHelper<uint64_t>();
 }
 
+template<typename T>
+void SaturatingMultiplyAddTestHelper()
+{
+  const T Max = std::numeric_limits<T>::max();
+  bool ResultOverflowed;
+
+  // Test basic multiply-add.
+  EXPECT_EQ(T(16), SaturatingMultiplyAdd(T(2), T(3), T(10)));
+  EXPECT_EQ(T(16), SaturatingMultiplyAdd(T(2), T(3), T(10), &ResultOverflowed));
+  EXPECT_FALSE(ResultOverflowed);
+
+  // Test multiply overflows, add doesn't overflow
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(Max, Max, T(0), &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  // Test multiply doesn't overflow, add overflows
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(T(1), T(1), Max, &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  // Test multiply-add with Max as operand
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(T(1), T(1), Max, &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(T(1), Max, T(1), &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(Max, Max, T(1), &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  EXPECT_EQ(Max, SaturatingMultiplyAdd(Max, Max, Max, &ResultOverflowed));
+  EXPECT_TRUE(ResultOverflowed);
+
+  // Test multiply-add with 0 as operand
+  EXPECT_EQ(T(1), SaturatingMultiplyAdd(T(1), T(1), T(0), &ResultOverflowed));
+  EXPECT_FALSE(ResultOverflowed);
+
+  EXPECT_EQ(T(1), SaturatingMultiplyAdd(T(1), T(0), T(1), &ResultOverflowed));
+  EXPECT_FALSE(ResultOverflowed);
+
+  EXPECT_EQ(T(1), SaturatingMultiplyAdd(T(0), T(0), T(1), &ResultOverflowed));
+  EXPECT_FALSE(ResultOverflowed);
+
+  EXPECT_EQ(T(0), SaturatingMultiplyAdd(T(0), T(0), T(0), &ResultOverflowed));
+  EXPECT_FALSE(ResultOverflowed);
+
+}
+
+TEST(MathExtras, SaturatingMultiplyAdd) {
+  SaturatingMultiplyAddTestHelper<uint8_t>();
+  SaturatingMultiplyAddTestHelper<uint16_t>();
+  SaturatingMultiplyAddTestHelper<uint32_t>();
+  SaturatingMultiplyAddTestHelper<uint64_t>();
+}
+
 }