X-Git-Url: http://demsky.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FCodeGen%2FSelectionDAG%2FSelectionDAGBuilder.cpp;h=85b2d5f62ff2bc0e6eba0c369a527236aa82c803;hb=d0872b393c76547945a8cedd28cbfe50b6764ae9;hp=c7ef51316b3637fb837adf7e607e693c893fd9aa;hpb=f6548b86b6b7153eb31d05dc6eafae9245735128;p=oota-llvm.git diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index c7ef51316b3..85b2d5f62ff 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -78,12 +78,16 @@ LimitFPPrecision("limit-float-precision", cl::location(LimitFloatPrecision), cl::init(0)); +static cl::opt +EnableFMFInDAG("enable-fmf-dag", cl::init(false), cl::Hidden, + cl::desc("Enable fast-math-flags for DAG nodes")); + // Limit the width of DAG chains. This is important in general to prevent -// prevent DAG-based analysis from blowing up. For example, alias analysis and +// DAG-based analysis from blowing up. For example, alias analysis and // load clustering may not complete in reasonable time. It is difficult to // recognize and avoid this situation within each individual analysis, and // future analyses are likely to have the same behavior. Limiting DAG width is -// the safe approach, and will be especially important with global DAGs. +// the safe approach and will be especially important with global DAGs. // // MaxParallelChains default is arbitrarily high to avoid affecting // optimization, but could be lowered to improve compile time. Any ld-ld-st-st @@ -1002,7 +1006,16 @@ bool SelectionDAGBuilder::findValue(const Value *V) const { SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) { // If we already have an SDValue for this value, use it. 
SDValue &N = NodeMap[V]; - if (N.getNode()) return N; + if (N.getNode()) { + if (isa(N) || isa(N)) { + // Remove the debug location from the node as the node is about to be used + // in a location which may differ from the original debug location. This + // is relevant to Constant and ConstantFP nodes because they can appear + // as constant expressions inside PHI nodes. + N->setDebugLoc(DebugLoc()); + } + return N; + } // Otherwise create a new SDValue and remember it. SDValue Val = getValueImpl(V); @@ -1432,8 +1445,8 @@ void SelectionDAGBuilder::FindMergedConditions(const Value *Cond, // We have flexibility in setting Prob for BB1 and Prob for TmpBB. // The requirement is that // TrueProb for BB1 + (FalseProb for BB1 * TrueProb for TmpBB) - // = TrueProb for orignal BB. - // Assuming the orignal weights are A and B, one choice is to set BB1's + // = TrueProb for original BB. + // Assuming the original weights are A and B, one choice is to set BB1's // weights to A and A+2B, and set TmpBB's weights to A and 2B. This choice // assumes that // TrueProb for BB1 == FalseProb for BB1 * TrueProb for TmpBB. @@ -1468,8 +1481,8 @@ void SelectionDAGBuilder::FindMergedConditions(const Value *Cond, // We have flexibility in setting Prob for BB1 and Prob for TmpBB. // The requirement is that // FalseProb for BB1 + (TrueProb for BB1 * FalseProb for TmpBB) - // = FalseProb for orignal BB. - // Assuming the orignal weights are A and B, one choice is to set BB1's + // = FalseProb for original BB. + // Assuming the original weights are A and B, one choice is to set BB1's // weights to 2A+B and B, and set TmpBB's weights to 2A and B. This choice // assumes that // FalseProb for BB1 == TrueProb for BB1 * FalseProb for TmpBB. 
@@ -2140,7 +2153,7 @@ void SelectionDAGBuilder::visitBinary(const User &I, unsigned OpCode) { bool nsw = false; bool exact = false; FastMathFlags FMF; - + if (const OverflowingBinaryOperator *OFBinOp = dyn_cast(&I)) { nuw = OFBinOp->hasNoUnsignedWrap(); @@ -2151,16 +2164,18 @@ void SelectionDAGBuilder::visitBinary(const User &I, unsigned OpCode) { exact = ExactOp->isExact(); if (const FPMathOperator *FPOp = dyn_cast(&I)) FMF = FPOp->getFastMathFlags(); - + SDNodeFlags Flags; Flags.setExact(exact); Flags.setNoSignedWrap(nsw); Flags.setNoUnsignedWrap(nuw); - Flags.setAllowReciprocal(FMF.allowReciprocal()); - Flags.setNoInfs(FMF.noInfs()); - Flags.setNoNaNs(FMF.noNaNs()); - Flags.setNoSignedZeros(FMF.noSignedZeros()); - Flags.setUnsafeAlgebra(FMF.unsafeAlgebra()); + if (EnableFMFInDAG) { + Flags.setAllowReciprocal(FMF.allowReciprocal()); + Flags.setNoInfs(FMF.noInfs()); + Flags.setNoNaNs(FMF.noNaNs()); + Flags.setNoSignedZeros(FMF.noSignedZeros()); + Flags.setUnsafeAlgebra(FMF.unsafeAlgebra()); + } SDValue BinNodeValue = DAG.getNode(OpCode, getCurSDLoc(), Op1.getValueType(), Op1, Op2, &Flags); setValue(&I, BinNodeValue); @@ -2223,17 +2238,11 @@ void SelectionDAGBuilder::visitSDiv(const User &I) { SDValue Op1 = getValue(I.getOperand(0)); SDValue Op2 = getValue(I.getOperand(1)); - // Turn exact SDivs into multiplications. - // FIXME: This should be in DAGCombiner, but it doesn't have access to the - // exact bit. 
- if (isa(&I) && cast(&I)->isExact() && - !isa(Op1) && - isa(Op2) && !cast(Op2)->isNullValue()) - setValue(&I, DAG.getTargetLoweringInfo() - .BuildExactSDIV(Op1, Op2, getCurSDLoc(), DAG)); - else - setValue(&I, DAG.getNode(ISD::SDIV, getCurSDLoc(), Op1.getValueType(), - Op1, Op2)); + SDNodeFlags Flags; + Flags.setExact(isa(&I) && + cast(&I)->isExact()); + setValue(&I, DAG.getNode(ISD::SDIV, getCurSDLoc(), Op1.getValueType(), Op1, + Op2, &Flags)); } void SelectionDAGBuilder::visitICmp(const User &I) { @@ -2273,19 +2282,51 @@ void SelectionDAGBuilder::visitSelect(const User &I) { SmallVector Values(NumValues); SDValue Cond = getValue(I.getOperand(0)); - SDValue TrueVal = getValue(I.getOperand(1)); - SDValue FalseVal = getValue(I.getOperand(2)); + SDValue LHSVal = getValue(I.getOperand(1)); + SDValue RHSVal = getValue(I.getOperand(2)); + auto BaseOps = {Cond}; ISD::NodeType OpCode = Cond.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT; - for (unsigned i = 0; i != NumValues; ++i) + // Min/max matching is only viable if all output VTs are the same. + if (std::equal(ValueVTs.begin(), ValueVTs.end(), ValueVTs.begin())) { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(const_cast(&I), LHS, RHS); + ISD::NodeType Opc = ISD::DELETED_NODE; + switch (SPF) { + case SPF_UMAX: Opc = ISD::UMAX; break; + case SPF_UMIN: Opc = ISD::UMIN; break; + case SPF_SMAX: Opc = ISD::SMAX; break; + case SPF_SMIN: Opc = ISD::SMIN; break; + default: break; + } + + EVT VT = ValueVTs[0]; + LLVMContext &Ctx = *DAG.getContext(); + auto &TLI = DAG.getTargetLoweringInfo(); + while (TLI.getTypeAction(Ctx, VT) == TargetLoweringBase::TypeSplitVector) + VT = TLI.getTypeToTransformTo(Ctx, VT); + + if (Opc != ISD::DELETED_NODE && TLI.isOperationLegalOrCustom(Opc, VT) && + // If the underlying comparison instruction is used by any other instruction, + // the consumed instructions won't be destroyed, so it is not profitable + // to convert to a min/max. 
+ cast(&I)->getCondition()->hasOneUse()) { + OpCode = Opc; + LHSVal = getValue(LHS); + RHSVal = getValue(RHS); + BaseOps = {}; + } + } + + for (unsigned i = 0; i != NumValues; ++i) { + SmallVector Ops(BaseOps.begin(), BaseOps.end()); + Ops.push_back(SDValue(LHSVal.getNode(), LHSVal.getResNo() + i)); + Ops.push_back(SDValue(RHSVal.getNode(), RHSVal.getResNo() + i)); Values[i] = DAG.getNode(OpCode, getCurSDLoc(), - TrueVal.getNode()->getValueType(TrueVal.getResNo()+i), - Cond, - SDValue(TrueVal.getNode(), - TrueVal.getResNo() + i), - SDValue(FalseVal.getNode(), - FalseVal.getResNo() + i)); + LHSVal.getNode()->getValueType(LHSVal.getResNo()+i), + Ops); + } setValue(&I, DAG.getNode(ISD::MERGE_VALUES, getCurSDLoc(), DAG.getVTList(ValueVTs), Values)); @@ -2836,7 +2877,17 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) { bool isVolatile = I.isVolatile(); bool isNonTemporal = I.getMetadata(LLVMContext::MD_nontemporal) != nullptr; - bool isInvariant = I.getMetadata(LLVMContext::MD_invariant_load) != nullptr; + + // The IR notion of invariant_load only guarantees that all *non-faulting* + // invariant loads result in the same value. The MI notion of invariant load + // guarantees that the load can be legally moved to any location within its + // containing function. The MI notion of invariant_load is stronger than the + // IR notion of invariant_load -- an MI invariant_load is an IR invariant_load + // with a guarantee that the location being loaded from is dereferenceable + // throughout the function's lifetime. + + bool isInvariant = I.getMetadata(LLVMContext::MD_invariant_load) != nullptr && + isDereferenceablePointer(SV, *DAG.getTarget().getDataLayout()); unsigned Alignment = I.getAlignment(); AAMDNodes AAInfo; @@ -2857,7 +2908,7 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) { // Serialize volatile loads with other side effects. 
Root = getRoot(); else if (AA->pointsToConstantMemory( - AliasAnalysis::Location(SV, AA->getTypeStoreSize(Ty), AAInfo))) { + MemoryLocation(SV, AA->getTypeStoreSize(Ty), AAInfo))) { // Do not serialize (non-volatile) loads of constant memory with anything. Root = DAG.getEntryNode(); ConstantMemory = true; @@ -2872,8 +2923,7 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) { Root = TLI.prepareVolatileOrAtomicLoad(Root, dl, DAG); SmallVector Values(NumValues); - SmallVector Chains(std::min(unsigned(MaxParallelChains), - NumValues)); + SmallVector Chains(std::min(MaxParallelChains, NumValues)); EVT PtrVT = Ptr.getValueType(); unsigned ChainI = 0; for (unsigned i = 0; i != NumValues; ++i, ++ChainI) { @@ -2937,8 +2987,7 @@ void SelectionDAGBuilder::visitStore(const StoreInst &I) { SDValue Ptr = getValue(PtrV); SDValue Root = getRoot(); - SmallVector Chains(std::min(unsigned(MaxParallelChains), - NumValues)); + SmallVector Chains(std::min(MaxParallelChains, NumValues)); EVT PtrVT = Ptr.getValueType(); bool isVolatile = I.isVolatile(); bool isNonTemporal = I.getMetadata(LLVMContext::MD_nontemporal) != nullptr; @@ -3106,10 +3155,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) { const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); SDValue InChain = DAG.getRoot(); - if (AA->pointsToConstantMemory( - AliasAnalysis::Location(PtrOperand, - AA->getTypeStoreSize(I.getType()), - AAInfo))) { + if (AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, AA->getTypeStoreSize(I.getType()), AAInfo))) { // Do not serialize (non-volatile) loads of constant memory with anything. 
InChain = DAG.getEntryNode(); } @@ -3151,10 +3198,9 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) { Value *BasePtr = Ptr; bool UniformBase = getUniformBase(BasePtr, Base, Index, this); bool ConstantMemory = false; - if (UniformBase && AA->pointsToConstantMemory( - AliasAnalysis::Location(BasePtr, - AA->getTypeStoreSize(I.getType()), - AAInfo))) { + if (UniformBase && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, AA->getTypeStoreSize(I.getType()), AAInfo))) { // Do not serialize (non-volatile) loads of constant memory with anything. Root = DAG.getEntryNode(); ConstantMemory = true; @@ -4033,16 +4079,20 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) { return nullptr; case Intrinsic::read_register: { Value *Reg = I.getArgOperand(0); + SDValue Chain = getRoot(); SDValue RegName = DAG.getMDNode(cast(cast(Reg)->getMetadata())); EVT VT = TLI.getValueType(I.getType()); - setValue(&I, DAG.getNode(ISD::READ_REGISTER, sdl, VT, RegName)); + Res = DAG.getNode(ISD::READ_REGISTER, sdl, + DAG.getVTList(VT, MVT::Other), Chain, RegName); + setValue(&I, Res); + DAG.setRoot(Res.getValue(1)); return nullptr; } case Intrinsic::write_register: { Value *Reg = I.getArgOperand(0); Value *RegValue = I.getArgOperand(1); - SDValue Chain = getValue(RegValue).getOperand(0); + SDValue Chain = getRoot(); SDValue RegName = DAG.getMDNode(cast(cast(Reg)->getMetadata())); DAG.setRoot(DAG.getNode(ISD::WRITE_REGISTER, sdl, MVT::Other, Chain, @@ -4920,11 +4970,9 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) { MF.getMMI().getContext().getOrCreateFrameAllocSymbol( GlobalValue::getRealLinkageName(Fn->getName()), IdxVal); - // Create a TargetExternalSymbol for the label to avoid any target lowering + // Create a MCSymbol for the label to avoid any target lowering // that would make this PC relative. 
- StringRef Name = FrameAllocSym->getName(); - assert(Name.data()[Name.size()] == '\0' && "not null terminated"); - SDValue OffsetSym = DAG.getTargetExternalSymbol(Name.data(), PtrVT); + SDValue OffsetSym = DAG.getMCSymbol(FrameAllocSym, PtrVT); SDValue OffsetVal = DAG.getNode(ISD::FRAME_ALLOC_RECOVER, sdl, PtrVT, OffsetSym); @@ -4944,6 +4992,7 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) { assert(Reg && "cannot get exception code on this platform"); MVT PtrVT = TLI.getPointerTy(); const TargetRegisterClass *PtrRC = TLI.getRegClassFor(PtrVT); + assert(FuncInfo.MBB->isLandingPad() && "eh.exceptioncode in non-lpad"); unsigned VReg = FuncInfo.MBB->addLiveIn(Reg, PtrRC); SDValue N = DAG.getCopyFromReg(DAG.getEntryNode(), getCurSDLoc(), VReg, PtrVT); @@ -4963,7 +5012,7 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, if (LandingPad) { // Insert a label before the invoke call to mark the try range. This can be // used to detect deletion of the invoke via the MachineModuleInfo. - BeginLabel = MMI.getContext().CreateTempSymbol(); + BeginLabel = MMI.getContext().createTempSymbol(); // For SjLj, keep track of which landing pads go with which invokes // so as to maintain the ordering of pads in the LSDA. @@ -5006,7 +5055,7 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, if (LandingPad) { // Insert a label at the end of the invoke call to mark the try range. This // can be used to detect deletion of the invoke via the MachineModuleInfo. - MCSymbol *EndLabel = MMI.getContext().CreateTempSymbol(); + MCSymbol *EndLabel = MMI.getContext().createTempSymbol(); DAG.setRoot(DAG.getEHLabel(getCurSDLoc(), getRoot(), EndLabel)); // Inform MachineModuleInfo of range. 
@@ -5434,7 +5483,7 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) { return; } } - if (unsigned IID = F->getIntrinsicID()) { + if (Intrinsic::ID IID = F->getIntrinsicID()) { RenameFn = visitIntrinsicCall(I, IID); if (!RenameFn) return; @@ -7421,7 +7470,7 @@ bool SelectionDAGBuilder::buildJumpTable(CaseClusterVector &Clusters, JumpTableHeader JTH(Clusters[First].Low->getValue(), Clusters[Last].High->getValue(), SI->getCondition(), nullptr, false); - JTCases.push_back(JumpTableBlock(JTH, JT)); + JTCases.emplace_back(std::move(JTH), std::move(JT)); JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High, JTCases.size() - 1, Weight); @@ -7447,6 +7496,31 @@ void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters, const int64_t N = Clusters.size(); const unsigned MinJumpTableSize = TLI.getMinimumJumpTableEntries(); + // TotalCases[i]: Total nbr of cases in Clusters[0..i]. + SmallVector TotalCases(N); + + for (unsigned i = 0; i < N; ++i) { + APInt Hi = Clusters[i].High->getValue(); + APInt Lo = Clusters[i].Low->getValue(); + TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; + if (i != 0) + TotalCases[i] += TotalCases[i - 1]; + } + + if (N >= MinJumpTableSize && isDense(Clusters, &TotalCases[0], 0, N - 1)) { + // Cheap case: the whole range might be suitable for jump table. + CaseCluster JTCluster; + if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) { + Clusters[0] = JTCluster; + Clusters.resize(1); + return; + } + } + + // The algorithm below is not suitable for -O0. + if (TM.getOptLevel() == CodeGenOpt::None) + return; + // Split Clusters into minimum number of dense partitions. 
The algorithm uses // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code // for the Case Statement'" (1994), but builds the MinPartitions array in @@ -7460,16 +7534,6 @@ void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters, SmallVector LastElement(N); // NumTables[i]: nbr of >= MinJumpTableSize partitions from Clusters[i..N-1]. SmallVector NumTables(N); - // TotalCases[i]: Total nbr of cases in Clusters[0..i]. - SmallVector TotalCases(N); - - for (unsigned i = 0; i < N; ++i) { - APInt Hi = Clusters[i].High->getValue(); - APInt Lo = Clusters[i].Low->getValue(); - TotalCases[i] = (Hi - Lo).getLimitedValue() + 1; - if (i != 0) - TotalCases[i] += TotalCases[i - 1]; - } // Base case: There is only one way to partition Clusters[N-1]. MinPartitions[N - 1] = 1; @@ -7584,7 +7648,7 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, const int BitWidth = DAG.getTargetLoweringInfo().getPointerTy().getSizeInBits(); - assert((High - Low + 1).sle(BitWidth) && "Case range must fit in bit mask!"); + assert(rangeFitsInWord(Low, High) && "Case range must fit in bit mask!"); if (Low.isNonNegative() && High.slt(BitWidth)) { // Optimize the case where all the case values fit in a @@ -7612,10 +7676,9 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, // Update Mask, Bits and ExtraWeight. 
uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue(); uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue(); - for (uint64_t j = Lo; j <= Hi; ++j) { - CB->Mask |= 1ULL << j; - CB->Bits++; - } + assert(Hi >= Lo && Hi < 64 && "Invalid bit case!"); + CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo; + CB->Bits += Hi - Lo + 1; CB->ExtraWeight += Clusters[i].Weight; TotalWeight += Clusters[i].Weight; assert(TotalWeight >= Clusters[i].Weight && "Weight overflow!"); @@ -7634,9 +7697,9 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters, FuncInfo.MF->CreateMachineBasicBlock(SI->getParent()); BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraWeight)); } - BitTestCases.push_back(BitTestBlock(LowBound, CmpRange, SI->getCondition(), - -1U, MVT::Other, false, nullptr, - nullptr, std::move(BTI))); + BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange), + SI->getCondition(), -1U, MVT::Other, false, nullptr, + nullptr, std::move(BTI)); BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High, BitTestCases.size() - 1, TotalWeight); @@ -7658,6 +7721,10 @@ void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters, assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue())); #endif + // The algorithm below is not suitable for -O0. + if (TM.getOptLevel() == CodeGenOpt::None) + return; + // If target does not have legal shift left, do not emit bit tests at all. 
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT PTy = TLI.getPointerTy(); @@ -7730,8 +7797,10 @@ void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters, if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) { Clusters[DstIndex++] = BitTestCluster; } else { - for (unsigned I = First; I <= Last; ++I) - std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I])); + size_t NumClusters = Last - First + 1; + std::memmove(&Clusters[DstIndex], &Clusters[First], + sizeof(Clusters[0]) * NumClusters); + DstIndex += NumClusters; } } Clusters.resize(DstIndex); @@ -7767,22 +7836,17 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond, const APInt &BigValue = Big.Low->getValue(); // Check that there is only one bit different. - if (BigValue.countPopulation() == SmallValue.countPopulation() + 1 && - (SmallValue | BigValue) == BigValue) { - // Isolate the common bit. - APInt CommonBit = BigValue & ~SmallValue; - assert((SmallValue | CommonBit) == BigValue && - CommonBit.countPopulation() == 1 && "Not a common bit?"); - + APInt CommonBit = BigValue ^ SmallValue; + if (CommonBit.isPowerOf2()) { SDValue CondLHS = getValue(Cond); EVT VT = CondLHS.getValueType(); SDLoc DL = getCurSDLoc(); SDValue Or = DAG.getNode(ISD::OR, DL, VT, CondLHS, DAG.getConstant(CommonBit, DL, VT)); - SDValue Cond = DAG.getSetCC(DL, MVT::i1, Or, - DAG.getConstant(BigValue, DL, VT), - ISD::SETEQ); + SDValue Cond = DAG.getSetCC( + DL, MVT::i1, Or, DAG.getConstant(BigValue | SmallValue, DL, VT), + ISD::SETEQ); // Update successor info. // Both Small and Big will jump to Small.BB, so we sum up the weights. 
@@ -7924,6 +7988,18 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond, } } +unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC, + CaseClusterIt First, + CaseClusterIt Last) { + return std::count_if(First, Last + 1, [&](const CaseCluster &X) { + if (X.Weight != CC.Weight) + return X.Weight > CC.Weight; + + // Ties are broken by comparing the case value. + return X.Low->getValue().slt(CC.Low->getValue()); + }); +} + void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W, Value *Cond, @@ -7953,6 +8029,48 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList, RightWeight += (--FirstRight)->Weight; I++; } + + for (;;) { + // Our binary search tree differs from a typical BST in that ours can have up + // to three values in each leaf. The pivot selection above doesn't take that + // into account, which means the tree might require more nodes and be less + // efficient. We compensate for this here. + + unsigned NumLeft = LastLeft - W.FirstCluster + 1; + unsigned NumRight = W.LastCluster - FirstRight + 1; + + if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) { + // If one side has less than 3 clusters, and the other has more than 3, + // consider taking a cluster from the other side. + + if (NumLeft < NumRight) { + // Consider moving the first cluster on the right to the left side. + CaseCluster &CC = *FirstRight; + unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); + unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); + if (LeftSideRank <= RightSideRank) { + // Moving the cluster to the left does not demote it. + ++LastLeft; + ++FirstRight; + continue; + } + } else { + assert(NumRight < NumLeft); + // Consider moving the last element on the left to the right side. 
+ CaseCluster &CC = *LastLeft; + unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft); + unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster); + if (RightSideRank <= LeftSideRank) { + // Moving the cluster to the right does not demote it. + --LastLeft; + --FirstRight; + continue; + } + } + } + break; + } + assert(LastLeft + 1 == FirstRight); assert(LastLeft >= W.FirstCluster); assert(FirstRight <= W.LastCluster); @@ -8076,11 +8194,8 @@ void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) { return; } - if (TM.getOptLevel() != CodeGenOpt::None) { - findJumpTables(Clusters, &SI, DefaultMBB); - findBitTestClusters(Clusters, &SI); - } - + findJumpTables(Clusters, &SI, DefaultMBB); + findBitTestClusters(Clusters, &SI); DEBUG({ dbgs() << "Case clusters: ";