[llvm-profdata] Add support for weighted merge of profile data (2nd try)
[oota-llvm.git] / include / llvm / ProfileData / SampleProf.h
index f62f79064c4e654f11f9f7454625e8bdfcbd08ae..7607e24ec1c84714cc38c94b60af1c384ea52bf0 100644 (file)
@@ -123,18 +123,36 @@ public:
   SampleRecord() : NumSamples(0), CallTargets() {}
 
   /// Increment the number of samples for this record by \p S.
+  /// Optionally scale sample count \p S by \p Weight.
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addSamples(uint64_t S) { NumSamples = SaturatingAdd(NumSamples, S); }
+  void addSamples(uint64_t S, uint64_t Weight = 1) {
+    // FIXME: Improve handling of counter overflow.
+    bool Overflowed;
+    if (Weight > 1) {
+      S = SaturatingMultiply(S, Weight, &Overflowed);
+      assert(!Overflowed && "Sample counter overflowed!");
+    }
+    NumSamples = SaturatingAdd(NumSamples, S, &Overflowed);
+    assert(!Overflowed && "Sample counter overflowed!");
+  }
 
   /// Add called function \p F with samples \p S.
+  /// Optionally scale sample count \p S by \p Weight.
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addCalledTarget(StringRef F, uint64_t S) {
+  void addCalledTarget(StringRef F, uint64_t S, uint64_t Weight = 1) {
+    // FIXME: Improve handling of counter overflow.
     uint64_t &TargetSamples = CallTargets[F];
-    TargetSamples = SaturatingAdd(TargetSamples, S);
+    bool Overflowed;
+    if (Weight > 1) {
+      S = SaturatingMultiply(S, Weight, &Overflowed);
+      assert(!Overflowed && "Called target counter overflowed!");
+    }
+    TargetSamples = SaturatingAdd(TargetSamples, S, &Overflowed);
+    assert(!Overflowed && "Called target counter overflowed!");
   }
 
   /// Return true if this sample record contains function calls.
@@ -144,10 +162,11 @@ public:
   const CallTargetMap &getCallTargets() const { return CallTargets; }
 
   /// Merge the samples in \p Other into this record.
-  void merge(const SampleRecord &Other) {
-    addSamples(Other.getSamples());
+  /// Optionally scale sample counts by \p Weight.
+  void merge(const SampleRecord &Other, uint64_t Weight = 1) {
+    addSamples(Other.getSamples(), Weight);
     for (const auto &I : Other.getCallTargets())
-      addCalledTarget(I.first(), I.second);
+      addCalledTarget(I.first(), I.second, Weight);
   }
 
   void print(raw_ostream &OS, unsigned Indent) const;
@@ -174,16 +193,36 @@ public:
   FunctionSamples() : TotalSamples(0), TotalHeadSamples(0) {}
   void print(raw_ostream &OS = dbgs(), unsigned Indent = 0) const;
   void dump() const;
-  void addTotalSamples(uint64_t Num) { TotalSamples += Num; }
-  void addHeadSamples(uint64_t Num) { TotalHeadSamples += Num; }
-  void addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
-                      uint64_t Num) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(Num);
+  void addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
+    // FIXME: Improve handling of counter overflow.
+    bool Overflowed;
+    if (Weight > 1) {
+      Num = SaturatingMultiply(Num, Weight, &Overflowed);
+      assert(!Overflowed && "Total samples counter overflowed!");
+    }
+    TotalSamples = SaturatingAdd(TotalSamples, Num, &Overflowed);
+    assert(!Overflowed && "Total samples counter overflowed!");
+  }
+  void addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
+    // FIXME: Improve handling of counter overflow.
+    bool Overflowed;
+    if (Weight > 1) {
+      Num = SaturatingMultiply(Num, Weight, &Overflowed);
+      assert(!Overflowed && "Total head samples counter overflowed!");
+    }
+    TotalHeadSamples = SaturatingAdd(TotalHeadSamples, Num, &Overflowed);
+    assert(!Overflowed && "Total head samples counter overflowed!");
+  }
+  void addBodySamples(uint32_t LineOffset, uint32_t Discriminator, uint64_t Num,
+                      uint64_t Weight = 1) {
+    BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(Num,
+                                                                    Weight);
   }
   void addCalledTargetSamples(uint32_t LineOffset, uint32_t Discriminator,
-                              std::string FName, uint64_t Num) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(FName,
-                                                                         Num);
+                              std::string FName, uint64_t Num,
+                              uint64_t Weight = 1) {
+    BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(
+        FName, Num, Weight);
   }
 
   /// Return the number of samples collected at the given location.
@@ -232,18 +271,19 @@ public:
   }
 
   /// Merge the samples in \p Other into this one.
-  void merge(const FunctionSamples &Other) {
-    addTotalSamples(Other.getTotalSamples());
-    addHeadSamples(Other.getHeadSamples());
+  /// Optionally scale samples by \p Weight.
+  void merge(const FunctionSamples &Other, uint64_t Weight = 1) {
+    addTotalSamples(Other.getTotalSamples(), Weight);
+    addHeadSamples(Other.getHeadSamples(), Weight);
     for (const auto &I : Other.getBodySamples()) {
       const LineLocation &Loc = I.first;
       const SampleRecord &Rec = I.second;
-      BodySamples[Loc].merge(Rec);
+      BodySamples[Loc].merge(Rec, Weight);
     }
     for (const auto &I : Other.getCallsiteSamples()) {
       const CallsiteLocation &Loc = I.first;
       const FunctionSamples &Rec = I.second;
-      functionSamplesAt(Loc).merge(Rec);
+      functionSamplesAt(Loc).merge(Rec, Weight);
     }
   }