[PGO] Fix a bug in InstProfWriter addRecord
authorXinliang David Li <davidxl@google.com>
Fri, 8 Jan 2016 03:49:59 +0000 (03:49 +0000)
committerXinliang David Li <davidxl@google.com>
Fri, 8 Jan 2016 03:49:59 +0000 (03:49 +0000)
For a new record with weight != 1, only edge profiling
counters are scaled, VP data is not properly scaled.

This patch refactors the code and fixes the problem.
Also added sort by count interface (for follow up patch).

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@257143 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ProfileData/InstrProf.h
lib/ProfileData/InstrProf.cpp
lib/ProfileData/InstrProfWriter.cpp
unittests/ProfileData/InstrProfTest.cpp

index 501aafdbfdc348a2cd1acff0acc1fb7f5b2c4b99..299763429006566cf2c651c4ae121146fa915a2c 100644 (file)
@@ -355,11 +355,14 @@ struct InstrProfValueSiteRecord {
           return left.Value < right.Value;
         });
   }
+  /// Sort ValueData Descending by Count
+  inline void sortByCount();
 
   /// Merge data from another InstrProfValueSiteRecord
   /// Optionally scale merged counts by \p Weight.
-  instrprof_error mergeValueData(InstrProfValueSiteRecord &Input,
-                                 uint64_t Weight = 1);
+  instrprof_error merge(InstrProfValueSiteRecord &Input, uint64_t Weight = 1);
+  /// Scale up value profile data counts.
+  instrprof_error scale(uint64_t Weight);
 };
 
 /// Profiling information for a single function.
@@ -402,6 +405,19 @@ struct InstrProfRecord {
   /// Optionally scale merged counts by \p Weight.
   instrprof_error merge(InstrProfRecord &Other, uint64_t Weight = 1);
 
+  /// Scale up profile counts (including value profile data) by
+  /// \p Weight.
+  instrprof_error scale(uint64_t Weight);
+
+  /// Sort value profile data (per site) by count.
+  void sortValueData() {
+    for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) {
+      std::vector<InstrProfValueSiteRecord> &SiteRecords =
+          getValueSitesForKind(Kind);
+      for (auto &SR : SiteRecords)
+        SR.sortByCount();
+    }
+  }
   /// Clear value data entries
   void clearValueData() {
     for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
@@ -436,6 +452,8 @@ private:
   // Scale merged value counts by \p Weight.
   instrprof_error mergeValueProfData(uint32_t ValueKind, InstrProfRecord &Src,
                                      uint64_t Weight);
+  // Scale up value profile data count.
+  instrprof_error scaleValueProfData(uint32_t ValueKind, uint64_t Weight);
 };
 
 uint32_t InstrProfRecord::getNumValueKinds() const {
@@ -503,11 +521,22 @@ inline support::endianness getHostEndianness() {
 #define INSTR_PROF_VALUE_PROF_DATA
 #include "llvm/ProfileData/InstrProfData.inc"
 
- /*
- * Initialize the record for runtime value profile data.
- * Return 0 if the initialization is successful, otherwise
- * return 1.
- */
+void InstrProfValueSiteRecord::sortByCount() {
+  ValueData.sort(
+      [](const InstrProfValueData &left, const InstrProfValueData &right) {
+        return left.Count > right.Count;
+      });
+  // Now truncate
+  size_t max_s = INSTR_PROF_MAX_NUM_VAL_PER_SITE;
+  if (ValueData.size() > max_s)
+    ValueData.resize(max_s);
+}
+
+/*
+* Initialize the record for runtime value profile data.
+* Return 0 if the initialization is successful, otherwise
+* return 1.
+*/
 int initializeValueProfRuntimeRecord(ValueProfRuntimeRecord *RuntimeRecord,
                                      const uint16_t *NumValueSites,
                                      ValueProfNode **Nodes);
index 027f0f78c54616af0fa480642d091dfc07b82fae..94c701de093c0031ad1832e3f1968d3f2c7af424 100644 (file)
@@ -257,9 +257,8 @@ int readPGOFuncNameStrings(StringRef NameStrings, InstrProfSymtab &Symtab) {
   return 0;
 }
 
-instrprof_error
-InstrProfValueSiteRecord::mergeValueData(InstrProfValueSiteRecord &Input,
-                                         uint64_t Weight) {
+instrprof_error InstrProfValueSiteRecord::merge(InstrProfValueSiteRecord &Input,
+                                                uint64_t Weight) {
   this->sortByTargetValues();
   Input.sortByTargetValues();
   auto I = ValueData.begin();
@@ -288,6 +287,17 @@ InstrProfValueSiteRecord::mergeValueData(InstrProfValueSiteRecord &Input,
   return Result;
 }
 
+instrprof_error InstrProfValueSiteRecord::scale(uint64_t Weight) {
+  instrprof_error Result = instrprof_error::success;
+  for (auto I = ValueData.begin(), IE = ValueData.end(); I != IE; ++I) {
+    bool Overflowed;
+    I->Count = SaturatingMultiply(I->Count, Weight, &Overflowed);
+    if (Overflowed)
+      Result = instrprof_error::counter_overflow;
+  }
+  return Result;
+}
+
 // Merge Value Profile data from Src record to this record for ValueKind.
 // Scale merged value counts by \p Weight.
 instrprof_error InstrProfRecord::mergeValueProfData(uint32_t ValueKind,
@@ -303,8 +313,7 @@ instrprof_error InstrProfRecord::mergeValueProfData(uint32_t ValueKind,
       Src.getValueSitesForKind(ValueKind);
   instrprof_error Result = instrprof_error::success;
   for (uint32_t I = 0; I < ThisNumValueSites; I++)
-    MergeResult(Result,
-                ThisSiteRecords[I].mergeValueData(OtherSiteRecords[I], Weight));
+    MergeResult(Result, ThisSiteRecords[I].merge(OtherSiteRecords[I], Weight));
   return Result;
 }
 
@@ -336,6 +345,32 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other,
   return Result;
 }
 
+instrprof_error InstrProfRecord::scaleValueProfData(uint32_t ValueKind,
+                                                    uint64_t Weight) {
+  uint32_t ThisNumValueSites = getNumValueSites(ValueKind);
+  std::vector<InstrProfValueSiteRecord> &ThisSiteRecords =
+      getValueSitesForKind(ValueKind);
+  instrprof_error Result = instrprof_error::success;
+  for (uint32_t I = 0; I < ThisNumValueSites; I++)
+    MergeResult(Result, ThisSiteRecords[I].scale(Weight));
+  return Result;
+}
+
+instrprof_error InstrProfRecord::scale(uint64_t Weight) {
+  instrprof_error Result = instrprof_error::success;
+  for (auto &Count : this->Counts) {
+    bool Overflowed;
+    Count = SaturatingMultiply(Count, Weight, &Overflowed);
+    if (Overflowed && Result == instrprof_error::success) {
+      Result = instrprof_error::counter_overflow;
+    }
+  }
+  for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
+    MergeResult(Result, scaleValueProfData(Kind, Weight));
+
+  return Result;
+}
+
 // Map indirect call target name hash to name string.
 uint64_t InstrProfRecord::remapValue(uint64_t Value, uint32_t ValueKind,
                                      ValueMapType *ValueMap) {
index 9bb03e1e77a38deda428e062385d465410028011..07667e1221e828154cfc621c0ecba5af42d586ab 100644 (file)
@@ -104,22 +104,14 @@ std::error_code InstrProfWriter::addRecord(InstrProfRecord &&I,
       ProfileDataMap.insert(std::make_pair(I.Hash, InstrProfRecord()));
   InstrProfRecord &Dest = Where->second;
 
-  instrprof_error Result;
+  instrprof_error Result = instrprof_error::success;
   if (NewFunc) {
     // We've never seen a function with this name and hash, add it.
     Dest = std::move(I);
     // Fix up the name to avoid dangling reference.
     Dest.Name = FunctionData.find(Dest.Name)->getKey();
-    Result = instrprof_error::success;
-    if (Weight > 1) {
-      for (auto &Count : Dest.Counts) {
-        bool Overflowed;
-        Count = SaturatingMultiply(Count, Weight, &Overflowed);
-        if (Overflowed && Result == instrprof_error::success) {
-          Result = instrprof_error::counter_overflow;
-        }
-      }
-    }
+    if (Weight > 1)
+      Result = Dest.scale(Weight);
   } else {
     // We're updating a function we've seen before.
     Result = Dest.merge(I, Weight);
index 1ccc3ca5d69591c8f5e4d5b1a4c433196e975830..a4fadd3d65d93256005720e84c662525401109ca 100644 (file)
@@ -164,6 +164,62 @@ TEST_F(InstrProfTest, get_icall_data_read_write) {
             [](const InstrProfValueData &VD1, const InstrProfValueData &VD2) {
               return VD1.Count > VD2.Count;
             });
+
+  ASSERT_EQ(3U, VD[0].Count);
+  ASSERT_EQ(2U, VD[1].Count);
+  ASSERT_EQ(1U, VD[2].Count);
+
+  ASSERT_EQ(StringRef((const char *)VD[0].Value, 7), StringRef("callee3"));
+  ASSERT_EQ(StringRef((const char *)VD[1].Value, 7), StringRef("callee2"));
+  ASSERT_EQ(StringRef((const char *)VD[2].Value, 7), StringRef("callee1"));
+}
+
+TEST_F(InstrProfTest, get_icall_data_read_write_with_weight) {
+  InstrProfRecord Record1("caller", 0x1234, {1, 2});
+  InstrProfRecord Record2("callee1", 0x1235, {3, 4});
+  InstrProfRecord Record3("callee2", 0x1235, {3, 4});
+  InstrProfRecord Record4("callee3", 0x1235, {3, 4});
+
+  // 4 value sites.
+  Record1.reserveSites(IPVK_IndirectCallTarget, 4);
+  InstrProfValueData VD0[] = {{(uint64_t) "callee1", 1},
+                              {(uint64_t) "callee2", 2},
+                              {(uint64_t) "callee3", 3}};
+  Record1.addValueData(IPVK_IndirectCallTarget, 0, VD0, 3, nullptr);
+  // No value profile data at the second site.
+  Record1.addValueData(IPVK_IndirectCallTarget, 1, nullptr, 0, nullptr);
+  InstrProfValueData VD2[] = {{(uint64_t) "callee1", 1},
+                              {(uint64_t) "callee2", 2}};
+  Record1.addValueData(IPVK_IndirectCallTarget, 2, VD2, 2, nullptr);
+  InstrProfValueData VD3[] = {{(uint64_t) "callee1", 1}};
+  Record1.addValueData(IPVK_IndirectCallTarget, 3, VD3, 1, nullptr);
+
+  Writer.addRecord(std::move(Record1), 10);
+  Writer.addRecord(std::move(Record2));
+  Writer.addRecord(std::move(Record3));
+  Writer.addRecord(std::move(Record4));
+  auto Profile = Writer.writeBuffer();
+  readProfile(std::move(Profile));
+
+  ErrorOr<InstrProfRecord> R = Reader->getInstrProfRecord("caller", 0x1234);
+  ASSERT_TRUE(NoError(R.getError()));
+  ASSERT_EQ(4U, R.get().getNumValueSites(IPVK_IndirectCallTarget));
+  ASSERT_EQ(3U, R.get().getNumValueDataForSite(IPVK_IndirectCallTarget, 0));
+  ASSERT_EQ(0U, R.get().getNumValueDataForSite(IPVK_IndirectCallTarget, 1));
+  ASSERT_EQ(2U, R.get().getNumValueDataForSite(IPVK_IndirectCallTarget, 2));
+  ASSERT_EQ(1U, R.get().getNumValueDataForSite(IPVK_IndirectCallTarget, 3));
+
+  std::unique_ptr<InstrProfValueData[]> VD =
+      R.get().getValueForSite(IPVK_IndirectCallTarget, 0);
+  // Now sort the target acording to frequency.
+  std::sort(&VD[0], &VD[3],
+            [](const InstrProfValueData &VD1, const InstrProfValueData &VD2) {
+              return VD1.Count > VD2.Count;
+            });
+  ASSERT_EQ(30U, VD[0].Count);
+  ASSERT_EQ(20U, VD[1].Count);
+  ASSERT_EQ(10U, VD[2].Count);
+
   ASSERT_EQ(StringRef((const char *)VD[0].Value, 7), StringRef("callee3"));
   ASSERT_EQ(StringRef((const char *)VD[1].Value, 7), StringRef("callee2"));
   ASSERT_EQ(StringRef((const char *)VD[2].Value, 7), StringRef("callee1"));