X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FVectorUtils.cpp;h=5fb517e8edb5578e08952b8a53023a7d8e7f3293;hb=813f44a29fd0fd140127023222d0633e23783bcc;hp=1ebff0f7c056ced889e780881f6865b49cb0e5e1;hpb=84bbcfe2009eafcde83fb7a7f54b7d1aad46f52a;p=oota-llvm.git diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp index 1ebff0f7c05..5fb517e8edb 100644 --- a/lib/Analysis/VectorUtils.cpp +++ b/lib/Analysis/VectorUtils.cpp @@ -11,13 +11,20 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" +#include "llvm/IR/Constants.h" + +using namespace llvm; +using namespace llvm::PatternMatch; /// \brief Identify if the intrinsic is trivially vectorizable. /// This method returns true if the intrinsic's argument types are all @@ -79,7 +86,7 @@ bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, /// d) call should only reads memory. /// If all these condition is met then return ValidIntrinsicID /// else return not_intrinsic. -llvm::Intrinsic::ID +Intrinsic::ID llvm::checkUnaryFloatSignature(const CallInst &I, Intrinsic::ID ValidIntrinsicID) { if (I.getNumArgOperands() != 1 || @@ -98,7 +105,7 @@ llvm::checkUnaryFloatSignature(const CallInst &I, /// d) call should only reads memory. /// If all these condition is met then return ValidIntrinsicID /// else return not_intrinsic. -llvm::Intrinsic::ID +Intrinsic::ID llvm::checkBinaryFloatSignature(const CallInst &I, Intrinsic::ID ValidIntrinsicID) { if (I.getNumArgOperands() != 2 || @@ -114,8 +121,8 @@ llvm::checkBinaryFloatSignature(const CallInst &I, /// \brief Returns intrinsic ID for call. /// For the input call instruction it finds mapping intrinsic and returns /// its ID, in case it does not found it return not_intrinsic. -llvm::Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI, - const TargetLibraryInfo *TLI) { +Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI, + const TargetLibraryInfo *TLI) { // If we have an intrinsic call, check if it is trivially vectorizable. if (IntrinsicInst *II = dyn_cast(CI)) { Intrinsic::ID ID = II->getIntrinsicID(); @@ -228,8 +235,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { cast(Gep->getType()->getScalarType())->getElementType()); // Walk backwards and try to peel off zeros. - while (LastOperand > 1 && - match(Gep->getOperand(LastOperand), llvm::PatternMatch::m_Zero())) { + while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { // Find the type we're currently indexing into. gep_type_iterator GEPTI = gep_type_begin(Gep); std::advance(GEPTI, LastOperand - 1); @@ -247,8 +253,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { /// \brief If the argument is a GEP, then returns the operand identified by /// getGEPInductionOperand. However, if there is some other non-loop-invariant /// operand, it returns that instead. -llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE, - Loop *Lp) { +Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { GetElementPtrInst *GEP = dyn_cast(Ptr); if (!GEP) return Ptr; @@ -265,8 +270,8 @@ llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE, } /// \brief If a value has only one user that is a CastInst, return it. -llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) { - llvm::Value *UniqueCast = nullptr; +Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { + Value *UniqueCast = nullptr; for (User *U : Ptr->users()) { CastInst *CI = dyn_cast(U); if (CI && CI->getType() == Ty) { @@ -281,8 +286,7 @@ llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) { /// \brief Get the stride of a pointer access in a loop. Looks for symbolic /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. -llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE, - Loop *Lp) { +Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { auto *PtrTy = dyn_cast(Ptr->getType()); if (!PtrTy || PtrTy->isAggregateType()) return nullptr; @@ -290,7 +294,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE, // Try to remove a gep instruction to make the pointer (actually index at this // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the // pointer, otherwise, we are analyzing the index. - llvm::Value *OrigPtr = Ptr; + Value *OrigPtr = Ptr; // The size of the pointer access. int64_t PtrAccessSize = 1; @@ -346,7 +350,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE, if (!U) return nullptr; - llvm::Value *Stride = U->getValue(); + Value *Stride = U->getValue(); if (!Lp->isLoopInvariant(Stride)) return nullptr; @@ -361,7 +365,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE, /// \brief Given a vector and an element number, see if the scalar value is /// already around as a register, for example if it were inserted then extracted /// from the vector. -llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) { +Value *llvm::findScalarElement(Value *V, unsigned EltNo) { assert(V->getType()->isVectorTy() && "Not looking at a vector?"); VectorType *VTy = cast(V->getType()); unsigned Width = VTy->getNumElements(); @@ -399,13 +403,166 @@ llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) { // Extract a value from a vector add operation with a constant zero. Value *Val = nullptr; Constant *Con = nullptr; - if (match(V, - llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val), - llvm::PatternMatch::m_Constant(Con)))) { - if (Con->getAggregateElement(EltNo)->isNullValue()) - return findScalarElement(Val, EltNo); - } + if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) + if (Constant *Elt = Con->getAggregateElement(EltNo)) + if (Elt->isNullValue()) + return findScalarElement(Val, EltNo); // Otherwise, we don't know. return nullptr; } + +/// \brief Get splat value if the input is a splat vector or return nullptr. +/// This function is not fully general. It checks only 2 cases: +/// the input value is (1) a splat constants vector or (2) a sequence +/// of instructions that broadcast a single value into a vector. +/// +const llvm::Value *llvm::getSplatValue(const Value *V) { + + if (auto *C = dyn_cast(V)) + if (isa(V->getType())) + return C->getSplatValue(); + + auto *ShuffleInst = dyn_cast(V); + if (!ShuffleInst) + return nullptr; + // All-zero (or undef) shuffle mask elements. + for (int MaskElt : ShuffleInst->getShuffleMask()) + if (MaskElt != 0 && MaskElt != -1) + return nullptr; + // The first shuffle source is 'insertelement' with index 0. + auto *InsertEltInst = + dyn_cast(ShuffleInst->getOperand(0)); + if (!InsertEltInst || !isa(InsertEltInst->getOperand(2)) || + !cast(InsertEltInst->getOperand(2))->isNullValue()) + return nullptr; + + return InsertEltInst->getOperand(1); +} + +MapVector +llvm::computeMinimumValueSizes(ArrayRef Blocks, DemandedBits &DB, + const TargetTransformInfo *TTI) { + + // DemandedBits will give us every value's live-out bits. But we want + // to ensure no extra casts would need to be inserted, so every DAG + // of connected values must have the same minimum bitwidth. + EquivalenceClasses ECs; + SmallVector Worklist; + SmallPtrSet Roots; + SmallPtrSet Visited; + DenseMap DBits; + SmallPtrSet InstructionSet; + MapVector MinBWs; + + // Determine the roots. We work bottom-up, from truncs or icmps. + bool SeenExtFromIllegalType = false; + for (auto *BB : Blocks) + for (auto &I : *BB) { + InstructionSet.insert(&I); + + if (TTI && (isa(&I) || isa(&I)) && + !TTI->isTypeLegal(I.getOperand(0)->getType())) + SeenExtFromIllegalType = true; + + // Only deal with non-vector integers up to 64-bits wide. + if ((isa(&I) || isa(&I)) && + !I.getType()->isVectorTy() && + I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { + // Don't make work for ourselves. If we know the loaded type is legal, + // don't add it to the worklist. + if (TTI && isa(&I) && TTI->isTypeLegal(I.getType())) + continue; + + Worklist.push_back(&I); + Roots.insert(&I); + } + } + // Early exit. + if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) + return MinBWs; + + // Now proceed breadth-first, unioning values together. + while (!Worklist.empty()) { + Value *Val = Worklist.pop_back_val(); + Value *Leader = ECs.getOrInsertLeaderValue(Val); + + if (Visited.count(Val)) + continue; + Visited.insert(Val); + + // Non-instructions terminate a chain successfully. + if (!isa(Val)) + continue; + Instruction *I = cast(Val); + + // If we encounter a type that is larger than 64 bits, we can't represent + // it so bail out. + if (DB.getDemandedBits(I).getBitWidth() > 64) + return MapVector(); + + uint64_t V = DB.getDemandedBits(I).getZExtValue(); + DBits[Leader] |= V; + + // Casts, loads and instructions outside of our range terminate a chain + // successfully. + if (isa(I) || isa(I) || isa(I) || + !InstructionSet.count(I)) + continue; + + // Unsafe casts terminate a chain unsuccessfully. We can't do anything + // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to + // transform anything that relies on them. + if (isa(I) || isa(I) || isa(I) || + !I->getType()->isIntegerTy()) { + DBits[Leader] |= ~0ULL; + continue; + } + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + if (isa(I)) + continue; + + if (DBits[Leader] == ~0ULL) + // All bits demanded, no point continuing. + continue; + + for (Value *O : cast(I)->operands()) { + ECs.unionSets(Leader, O); + Worklist.push_back(O); + } + } + + // Now we've discovered all values, walk them to see if there are + // any users we didn't see. If there are, we can't optimize that + // chain. + for (auto &I : DBits) + for (auto *U : I.first->users()) + if (U->getType()->isIntegerTy() && DBits.count(U) == 0) + DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; + + for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { + uint64_t LeaderDemandedBits = 0; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + LeaderDemandedBits |= DBits[*MI]; + + uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - + llvm::countLeadingZeros(LeaderDemandedBits); + // Round up to a power of 2 + if (!isPowerOf2_64((uint64_t)MinBW)) + MinBW = NextPowerOf2(MinBW); + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { + if (!isa(*MI)) + continue; + Type *Ty = (*MI)->getType(); + if (Roots.count(*MI)) + Ty = cast(*MI)->getOperand(0)->getType(); + if (MinBW < Ty->getScalarSizeInBits()) + MinBWs[cast(*MI)] = MinBW; + } + } + + return MinBWs; +}