if (Begin == End)
return;
- auto UnknownProbCount =
- std::count(Begin, End, BranchProbability::getUnknown());
- assert((UnknownProbCount == 0 ||
- UnknownProbCount == std::distance(Begin, End)) &&
- "Cannot normalize probabilities with known and unknown ones.");
- (void)UnknownProbCount;
-
- uint64_t Sum = std::accumulate(
- Begin, End, uint64_t(0),
- [](uint64_t S, const BranchProbability &BP) { return S + BP.N; });
-
+ unsigned UnknownProbCount = 0;
+ uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
+ [&](uint64_t S, const BranchProbability &BP) {
+ if (!BP.isUnknown())
+ return S + BP.N;
+ UnknownProbCount++;
+ return S;
+ });
+
+ if (UnknownProbCount > 0) {
+ BranchProbability ProbForUnknown = BranchProbability::getZero();
+ // If the sum of all known probabilities is less than one, evenly distribute
+ // the complement of sum to unknown probabilities. Otherwise, set unknown
+ // probabilities to zeros and continue to normalize known probabilities.
+ if (Sum < BranchProbability::getDenominator())
+ ProbForUnknown = BranchProbability::getRaw(
+ (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
+
+ std::replace_if(Begin, End,
+ [](const BranchProbability &BP) { return BP.isUnknown(); },
+ ProbForUnknown);
+
+ if (Sum <= BranchProbability::getDenominator())
+ return;
+ }
+
if (Sum == 0) {
BranchProbability BP(1, std::distance(Begin, End));
std::fill(Begin, End, BP);
}
TEST(BranchProbabilityTest, NormalizeProbabilities) {
+ const auto UnknownProb = BranchProbability::getUnknown();
{
SmallVector<BranchProbability, 2> Probs{{0, 1}, {0, 1}};
BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
Probs[2].getNumerator());
}
+ {
+ SmallVector<BranchProbability, 2> Probs{{0, 1}, UnknownProb};
+ BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+ EXPECT_EQ(0, Probs[0].getNumerator());
+ EXPECT_EQ(BranchProbability::getDenominator(), Probs[1].getNumerator());
+ }
+ {
+ SmallVector<BranchProbability, 2> Probs{{1, 1}, UnknownProb};
+ BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+ EXPECT_EQ(BranchProbability::getDenominator(), Probs[0].getNumerator());
+ EXPECT_EQ(0, Probs[1].getNumerator());
+ }
+ {
+ SmallVector<BranchProbability, 2> Probs{{1, 2}, UnknownProb};
+ BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+ EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[0].getNumerator());
+ EXPECT_EQ(BranchProbability::getDenominator() / 2, Probs[1].getNumerator());
+ }
+ {
+ SmallVector<BranchProbability, 4> Probs{
+ {1, 2}, {1, 2}, {1, 2}, UnknownProb};
+ BranchProbability::normalizeProbabilities(Probs.begin(), Probs.end());
+ EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+ Probs[0].getNumerator());
+ EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+ Probs[1].getNumerator());
+ EXPECT_EQ(BranchProbability::getDenominator() / 3 + 1,
+ Probs[2].getNumerator());
+ EXPECT_EQ(0, Probs[3].getNumerator());
+ }
}
}