From 5fa57d9ce38d76a6779730c6e30c8a3b970a3a9c Mon Sep 17 00:00:00 2001 From: Cong Hou Date: Tue, 13 Oct 2015 22:27:41 +0000 Subject: [PATCH] Update MachineBranchProbabilityInfo::normalizeEdgeWeights to make sure there is no zero weight in the output, and also add a missing test for JumpThreading. The test is for the patch in http://reviews.llvm.org/D10979 but was missing when committing that patch. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@250240 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../CodeGen/MachineBranchProbabilityInfo.h | 52 +++++++++++-------- .../JumpThreading/update-edge-weight.ll | 43 +++++++++++++++ 2 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 test/Transforms/JumpThreading/update-edge-weight.ll diff --git a/include/llvm/CodeGen/MachineBranchProbabilityInfo.h b/include/llvm/CodeGen/MachineBranchProbabilityInfo.h index 26f0d993738..21e2dbb5722 100644 --- a/include/llvm/CodeGen/MachineBranchProbabilityInfo.h +++ b/include/llvm/CodeGen/MachineBranchProbabilityInfo.h @@ -86,35 +86,43 @@ public: const MachineBasicBlock *Dst) const; // Normalize a list of weights by scaling them down so that the sum of them - // doesn't exceed UINT32_MAX. Return the scale. + // doesn't exceed UINT32_MAX. template - static uint32_t normalizeEdgeWeights(WeightListIter Begin, - WeightListIter End); + static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End); }; template -uint32_t -MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin, - WeightListIter End) { +void MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin, + WeightListIter End) { // First we compute the sum with 64-bits of precision. uint64_t Sum = std::accumulate(Begin, End, uint64_t(0)); - // If Sum is zero, set all weights to 1. - if (Sum == 0) - std::fill(Begin, End, uint64_t(1)); - - // If the computed sum fits in 32-bits, we're done. - if (Sum <= UINT32_MAX) - return 1; - - // Otherwise, compute the scale necessary to cause the weights to fit, and - // re-sum with that scale applied. - assert((Sum / UINT32_MAX) < UINT32_MAX && - "The sum of weights exceeds UINT32_MAX^2!"); - uint32_t Scale = (Sum / UINT32_MAX) + 1; - for (auto I = Begin; I != End; ++I) - *I /= Scale; - return Scale; + if (Sum > UINT32_MAX) { + // Compute the scale necessary to cause the weights to fit, and re-sum with + // that scale applied. + assert(Sum / UINT32_MAX < UINT32_MAX && + "The sum of weights exceeds UINT32_MAX^2!"); + uint32_t Scale = Sum / UINT32_MAX + 1; + for (auto I = Begin; I != End; ++I) + *I /= Scale; + Sum = std::accumulate(Begin, End, uint64_t(0)); + } + + // Eliminate zero weights. + auto ZeroWeightNum = std::count(Begin, End, 0u); + if (ZeroWeightNum > 0) { + // If all weights are zeros, replace them by 1. + if (Sum == 0) + std::fill(Begin, End, 1u); + else { + // Scale up non-zero weights and turn zero weights into ones. + uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum; + if (ScalingFactor > 1) + for (auto I = Begin; I != End; ++I) + *I *= ScalingFactor; + std::replace(Begin, End, 0u, 1u); + } + } } } diff --git a/test/Transforms/JumpThreading/update-edge-weight.ll b/test/Transforms/JumpThreading/update-edge-weight.ll new file mode 100644 index 00000000000..b5c5d01a3c6 --- /dev/null +++ b/test/Transforms/JumpThreading/update-edge-weight.ll @@ -0,0 +1,43 @@ +; RUN: opt -S -jump-threading %s | FileCheck %s + +; Test if edge weights are properly updated after jump threading. + +; CHECK: !2 = !{!"branch_weights", i32 22, i32 7} + +define void @foo(i32 %n) !prof !0 { +entry: + %cmp = icmp sgt i32 %n, 10 + br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1 + +if.then.1: + tail call void @a() + br label %if.cond + +if.else.1: + tail call void @b() + br label %if.cond + +if.cond: + %cmp1 = icmp sgt i32 %n, 5 + br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2 + +if.then.2: + tail call void @c() + br label %if.end + +if.else.2: + tail call void @d() + br label %if.end + +if.end: + ret void +} + +declare void @a() +declare void @b() +declare void @c() +declare void @d() + +!0 = !{!"function_entry_count", i64 1} +!1 = !{!"branch_weights", i32 10, i32 5} +!2 = !{!"branch_weights", i32 10, i32 1} -- 2.34.1