X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FVectorUtils.cpp;h=5fb517e8edb5578e08952b8a53023a7d8e7f3293;hb=813f44a29fd0fd140127023222d0633e23783bcc;hp=93720857662f988a8e0a046cd10cdfeb7afbe165;hpb=957622ea93878ccc2cc43f57b1ce1b45d052da3e;p=oota-llvm.git diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp index 93720857662..5fb517e8edb 100644 --- a/lib/Analysis/VectorUtils.cpp +++ b/lib/Analysis/VectorUtils.cpp @@ -11,9 +11,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/PatternMatch.h" @@ -414,9 +417,11 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) { /// the input value is (1) a splat constants vector or (2) a sequence /// of instructions that broadcast a single value into a vector. /// -llvm::Value *llvm::getSplatValue(Value *V) { - if (auto *CV = dyn_cast(V)) - return CV->getSplatValue(); +const llvm::Value *llvm::getSplatValue(const Value *V) { + + if (auto *C = dyn_cast(V)) + if (isa(V->getType())) + return C->getSplatValue(); auto *ShuffleInst = dyn_cast(V); if (!ShuffleInst) @@ -434,3 +439,130 @@ llvm::Value *llvm::getSplatValue(Value *V) { return InsertEltInst->getOperand(1); } + +MapVector +llvm::computeMinimumValueSizes(ArrayRef Blocks, DemandedBits &DB, + const TargetTransformInfo *TTI) { + + // DemandedBits will give us every value's live-out bits. But we want + // to ensure no extra casts would need to be inserted, so every DAG + // of connected values must have the same minimum bitwidth. + EquivalenceClasses ECs; + SmallVector Worklist; + SmallPtrSet Roots; + SmallPtrSet Visited; + DenseMap DBits; + SmallPtrSet InstructionSet; + MapVector MinBWs; + + // Determine the roots. We work bottom-up, from truncs or icmps. + bool SeenExtFromIllegalType = false; + for (auto *BB : Blocks) + for (auto &I : *BB) { + InstructionSet.insert(&I); + + if (TTI && (isa(&I) || isa(&I)) && + !TTI->isTypeLegal(I.getOperand(0)->getType())) + SeenExtFromIllegalType = true; + + // Only deal with non-vector integers up to 64-bits wide. + if ((isa(&I) || isa(&I)) && + !I.getType()->isVectorTy() && + I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { + // Don't make work for ourselves. If we know the loaded type is legal, + // don't add it to the worklist. + if (TTI && isa(&I) && TTI->isTypeLegal(I.getType())) + continue; + + Worklist.push_back(&I); + Roots.insert(&I); + } + } + // Early exit. + if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) + return MinBWs; + + // Now proceed breadth-first, unioning values together. + while (!Worklist.empty()) { + Value *Val = Worklist.pop_back_val(); + Value *Leader = ECs.getOrInsertLeaderValue(Val); + + if (Visited.count(Val)) + continue; + Visited.insert(Val); + + // Non-instructions terminate a chain successfully. + if (!isa(Val)) + continue; + Instruction *I = cast(Val); + + // If we encounter a type that is larger than 64 bits, we can't represent + // it so bail out. + if (DB.getDemandedBits(I).getBitWidth() > 64) + return MapVector(); + + uint64_t V = DB.getDemandedBits(I).getZExtValue(); + DBits[Leader] |= V; + + // Casts, loads and instructions outside of our range terminate a chain + // successfully. + if (isa(I) || isa(I) || isa(I) || + !InstructionSet.count(I)) + continue; + + // Unsafe casts terminate a chain unsuccessfully. We can't do anything + // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to + // transform anything that relies on them. + if (isa(I) || isa(I) || isa(I) || + !I->getType()->isIntegerTy()) { + DBits[Leader] |= ~0ULL; + continue; + } + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + if (isa(I)) + continue; + + if (DBits[Leader] == ~0ULL) + // All bits demanded, no point continuing. + continue; + + for (Value *O : cast(I)->operands()) { + ECs.unionSets(Leader, O); + Worklist.push_back(O); + } + } + + // Now we've discovered all values, walk them to see if there are + // any users we didn't see. If there are, we can't optimize that + // chain. + for (auto &I : DBits) + for (auto *U : I.first->users()) + if (U->getType()->isIntegerTy() && DBits.count(U) == 0) + DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; + + for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { + uint64_t LeaderDemandedBits = 0; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + LeaderDemandedBits |= DBits[*MI]; + + uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - + llvm::countLeadingZeros(LeaderDemandedBits); + // Round up to a power of 2 + if (!isPowerOf2_64((uint64_t)MinBW)) + MinBW = NextPowerOf2(MinBW); + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { + if (!isa(*MI)) + continue; + Type *Ty = (*MI)->getType(); + if (Roots.count(*MI)) + Ty = cast(*MI)->getOperand(0)->getType(); + if (MinBW < Ty->getScalarSizeInBits()) + MinBWs[cast(*MI)] = MinBW; + } + } + + return MinBWs; +}