From 4f6b4674be5473319ac5e70c76fd5cb964da2128 Mon Sep 17 00:00:00 2001 From: Evan Cheng Date: Wed, 21 Jul 2010 06:09:07 +0000 Subject: [PATCH] Teach bottom up pre-ra scheduler to track register pressure. Work in progress. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@108991 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Target/TargetLowering.h | 25 +- .../SelectionDAG/ScheduleDAGRRList.cpp | 244 ++++++++++++++++-- lib/CodeGen/SelectionDAG/TargetLowering.cpp | 20 +- lib/Target/ARM/ARMISelLowering.cpp | 42 ++- lib/Target/ARM/ARMISelLowering.h | 4 +- lib/Target/PIC16/PIC16ISelLowering.cpp | 10 + lib/Target/PIC16/PIC16ISelLowering.h | 3 + 7 files changed, 306 insertions(+), 42 deletions(-) diff --git a/include/llvm/Target/TargetLowering.h b/include/llvm/Target/TargetLowering.h index 926efc4eb73..285c4be5bff 100644 --- a/include/llvm/Target/TargetLowering.h +++ b/include/llvm/Target/TargetLowering.h @@ -174,12 +174,18 @@ public: /// For example, on i386 the rep register class for i8, i16, and i32 are GR32; /// while the rep register class is GR64 on x86_64. virtual const TargetRegisterClass *getRepRegClassFor(EVT VT) const { - assert(VT.isSimple() && "getRegClassFor called on illegal type!"); + assert(VT.isSimple() && "getRepRegClassFor called on illegal type!"); const TargetRegisterClass *RC = RepRegClassForVT[VT.getSimpleVT().SimpleTy]; - assert(RC && "This value type is not natively supported!"); return RC; } + /// getRepRegClassCostFor - Return the cost of the 'representative' register + /// class for the specified value type. + virtual uint8_t getRepRegClassCostFor(EVT VT) const { + assert(VT.isSimple() && "getRepRegClassCostFor called on illegal type!"); + return RepRegClassCostForVT[VT.getSimpleVT().SimpleTy]; + } + /// isTypeLegal - Return true if the target has native support for the /// specified value type. This means that it has a register that directly /// holds it without promotions or expansions. @@ -994,9 +1000,9 @@ protected: } /// findRepresentativeClass - Return the largest legal super-reg register class - /// of the specified register class. - virtual const TargetRegisterClass * - findRepresentativeClass(const TargetRegisterClass *RC) const; + /// of the register class for the specified type and its associated "cost". + virtual std::pair + findRepresentativeClass(EVT VT) const; /// computeRegisterProperties - Once all of the register classes are added, /// this allows us to compute derived properties we expose. @@ -1581,10 +1587,17 @@ private: /// RepRegClassForVT - This indicates the "representative" register class to /// use for each ValueType the target supports natively. This information is - /// used by the scheduler to track register pressure. e.g. On x86, i8, i16, + /// used by the scheduler to track register pressure. By default, the + /// representative register class is the largest legal super-reg register + /// class of the register class of the specified type. e.g. On x86, i8, i16, /// and i32's representative class would be GR32. const TargetRegisterClass *RepRegClassForVT[MVT::LAST_VALUETYPE]; + /// RepRegClassCostForVT - This indicates the "cost" of the "representative" + /// register class for each ValueType. The cost is used by the scheduler to + /// approximate register pressure. + uint8_t RepRegClassCostForVT[MVT::LAST_VALUETYPE]; + /// Synthesizable indicates whether it is OK for the compiler to create new /// operations using this type. All Legal types are Synthesizable except /// MMX types on X86. Non-Legal types are not Synthesizable. diff --git a/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp b/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp index 3ef521c398e..9503b0c08a6 100644 --- a/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp +++ b/lib/CodeGen/SelectionDAG/ScheduleDAGRRList.cpp @@ -24,15 +24,20 @@ #include "llvm/Target/TargetData.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetInstrInfo.h" +#include "llvm/Target/TargetLowering.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include using namespace llvm; +static cl::opt RegPressureAware("reg-pressure-aware-sched", + cl::init(false), cl::Hidden); + STATISTIC(NumBacktracks, "Number of times scheduler backtracked"); STATISTIC(NumUnfolds, "Number of nodes unfolded"); STATISTIC(NumDups, "Number of duplicated nodes"); @@ -181,7 +186,9 @@ private: /// Schedule - Schedule the DAG using list scheduling. void ScheduleDAGRRList::Schedule() { - DEBUG(dbgs() << "********** List Scheduling **********\n"); + DEBUG(dbgs() + << "********** List Scheduling BB#" << BB->getNumber() + << " **********\n"); NumLiveRegs = 0; LiveRegDefs.resize(TRI->getNumRegs(), NULL); @@ -989,7 +996,7 @@ namespace { : SPQ(spq) {} hybrid_ls_rr_sort(const hybrid_ls_rr_sort &RHS) : SPQ(RHS.SPQ) {} - + bool operator()(const SUnit* left, const SUnit* right) const; }; } // end anonymous namespace @@ -1029,23 +1036,46 @@ namespace { std::vector Queue; SF Picker; unsigned CurQueueId; + bool isBottomUp; protected: // SUnits - The SUnits for the current graph. std::vector *SUnits; - + + MachineFunction &MF; const TargetInstrInfo *TII; const TargetRegisterInfo *TRI; + const TargetLowering *TLI; ScheduleDAGRRList *scheduleDAG; // SethiUllmanNumbers - The SethiUllman number for each node. std::vector SethiUllmanNumbers; + /// RegPressure - Tracking current reg pressure per register class. + /// + std::vector RegPressure; + + /// RegLimit - Tracking the number of allocatable registers per register + /// class. + std::vector RegLimit; + public: - RegReductionPriorityQueue(const TargetInstrInfo *tii, - const TargetRegisterInfo *tri) - : Picker(this), CurQueueId(0), - TII(tii), TRI(tri), scheduleDAG(NULL) {} + RegReductionPriorityQueue(MachineFunction &mf, + bool isbottomup, + const TargetInstrInfo *tii, + const TargetRegisterInfo *tri, + const TargetLowering *tli) + : Picker(this), CurQueueId(0), isBottomUp(isbottomup), + MF(mf), TII(tii), TRI(tri), TLI(tli), scheduleDAG(NULL) { + unsigned NumRC = TRI->getNumRegClasses(); + RegLimit.resize(NumRC); + RegPressure.resize(NumRC); + std::fill(RegLimit.begin(), RegLimit.end(), 0); + std::fill(RegPressure.begin(), RegPressure.end(), 0); + for (TargetRegisterInfo::regclass_iterator I = TRI->regclass_begin(), + E = TRI->regclass_end(); I != E; ++I) + RegLimit[(*I)->getID()] = tri->getAllocatableSet(MF, *I).count() - 1; + } void initNodes(std::vector &sunits) { SUnits = &sunits; @@ -1072,6 +1102,7 @@ namespace { void releaseState() { SUnits = 0; SethiUllmanNumbers.clear(); + std::fill(RegPressure.begin(), RegPressure.end(), 0); } unsigned getNodePriority(const SUnit *SU) const { @@ -1139,10 +1170,191 @@ namespace { SU->NodeQueueId = 0; } + // EstimateSpills - Given a scheduling unit, estimate the number of spills + // it would cause by scheduling it at the current cycle. + unsigned EstimateSpills(const SUnit *SU) const { + if (!TLI) + return 0; + + unsigned Spills = 0; + for (SUnit::const_pred_iterator I = SU->Preds.begin(),E = SU->Preds.end(); + I != E; ++I) { + if (I->isCtrl()) + continue; + SUnit *PredSU = I->getSUnit(); + if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1) + continue; + const SDNode *N = PredSU->getNode(); + if (!N->isMachineOpcode()) + continue; + unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs(); + for (unsigned i = 0; i != NumDefs; ++i) { + EVT VT = N->getValueType(i); + if (!N->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + unsigned Cost = TLI->getRepRegClassCostFor(VT); + // Check if this increases register pressure of the specific register + // class to the point where it would cause spills. + int Excess = RegPressure[RCId] + Cost - RegLimit[RCId]; + if (Excess > 0) + Spills += Excess; + } + } + + if (!SU->NumSuccs || !Spills) + return Spills; + const SDNode *N = SU->getNode(); + if (!N->isMachineOpcode()) + return Spills; + unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs(); + for (unsigned i = 0; i != NumDefs; ++i) { + EVT VT = N->getValueType(i); + if (!N->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + unsigned Cost = TLI->getRepRegClassCostFor(VT); + if (RegPressure[RCId] > RegLimit[RCId]) { + int Less = RegLimit[RCId] - (RegPressure[RCId] - Cost); + if (Less > 0) { + if (Spills <= (unsigned)Less) + return 0; + Spills -= Less; + } + } + } + + return Spills; + } + + void OpenPredLives(SUnit *SU) { + const SDNode *N = SU->getNode(); + if (!N->isMachineOpcode()) + return; + unsigned Opc = N->getMachineOpcode(); + if (Opc == TargetOpcode::EXTRACT_SUBREG || + Opc == TargetOpcode::INSERT_SUBREG || + Opc == TargetOpcode::SUBREG_TO_REG || + Opc == TargetOpcode::COPY_TO_REGCLASS || + Opc == TargetOpcode::REG_SEQUENCE || + Opc == TargetOpcode::IMPLICIT_DEF) + return; + + for (SUnit::pred_iterator I = SU->Preds.begin(), E = SU->Preds.end(); + I != E; ++I) { + if (I->isCtrl()) + continue; + SUnit *PredSU = I->getSUnit(); + if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1) + continue; + const SDNode *PN = PredSU->getNode(); + if (!PN->isMachineOpcode()) + continue; + unsigned NumDefs = TII->get(PN->getMachineOpcode()).getNumDefs(); + for (unsigned i = 0; i != NumDefs; ++i) { + EVT VT = PN->getValueType(i); + if (!PN->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + RegPressure[RCId] += TLI->getRepRegClassCostFor(VT); + } + } + + if (!SU->NumSuccs) + return; + unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs(); + for (unsigned i = 0; i != NumDefs; ++i) { + EVT VT = N->getValueType(i); + if (!N->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + RegPressure[RCId] -= TLI->getRepRegClassCostFor(VT); + if (RegPressure[RCId] < 0) + // Register pressure tracking is imprecise. This can happen. + RegPressure[RCId] = 0; + } + } + + void ClosePredLives(SUnit *SU) { + const SDNode *N = SU->getNode(); + if (!N->isMachineOpcode()) + return; + unsigned Opc = N->getMachineOpcode(); + if (Opc == TargetOpcode::EXTRACT_SUBREG || + Opc == TargetOpcode::INSERT_SUBREG || + Opc == TargetOpcode::SUBREG_TO_REG || + Opc == TargetOpcode::COPY_TO_REGCLASS || + Opc == TargetOpcode::REG_SEQUENCE || + Opc == TargetOpcode::IMPLICIT_DEF) + return; + + for (SUnit::pred_iterator I = SU->Preds.begin(), E = SU->Preds.end(); + I != E; ++I) { + if (I->isCtrl()) + continue; + SUnit *PredSU = I->getSUnit(); + if (PredSU->NumSuccsLeft != PredSU->NumSuccs - 1) + continue; + const SDNode *PN = PredSU->getNode(); + if (!PN->isMachineOpcode()) + continue; + unsigned NumDefs = TII->get(PN->getMachineOpcode()).getNumDefs(); + for (unsigned i = 0; i != NumDefs; ++i) { + EVT VT = PN->getValueType(i); + if (!PN->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + RegPressure[RCId] -= TLI->getRepRegClassCostFor(VT); + if (RegPressure[RCId] < 0) + // Register pressure tracking is imprecise. This can happen. + RegPressure[RCId] = 0; + } + } + + if (!SU->NumSuccs) + return; + unsigned NumDefs = TII->get(N->getMachineOpcode()).getNumDefs(); + for (unsigned i = NumDefs, e = N->getNumValues(); i != e; ++i) { + EVT VT = N->getValueType(i); + if (VT == MVT::Flag || VT == MVT::Other) + continue; + if (!N->hasAnyUseOfValue(i)) + continue; + unsigned RCId = TLI->getRepRegClassFor(VT)->getID(); + RegPressure[RCId] += TLI->getRepRegClassCostFor(VT); + } + } + + void ScheduledNode(SUnit *SU) { + if (!TLI || !isBottomUp) + return; + OpenPredLives(SU); + dumpRegPressure(); + } + + void UnscheduledNode(SUnit *SU) { + if (!TLI || !isBottomUp) + return; + ClosePredLives(SU); + dumpRegPressure(); + } + void setScheduleDAG(ScheduleDAGRRList *scheduleDag) { scheduleDAG = scheduleDag; } + void dumpRegPressure() const { + for (TargetRegisterInfo::regclass_iterator I = TRI->regclass_begin(), + E = TRI->regclass_end(); I != E; ++I) { + const TargetRegisterClass *RC = *I; + unsigned Id = RC->getID(); + unsigned RP = RegPressure[Id]; + if (!RP) continue; + DEBUG(dbgs() << RC->getName() << ": " << RP << " / " << RegLimit[Id] + << '\n'); + } + } + protected: bool canClobber(const SUnit *SU, const SUnit *Op); void AddPseudoTwoAddrDeps(); @@ -1635,8 +1847,8 @@ llvm::createBURRListDAGScheduler(SelectionDAGISel *IS, CodeGenOpt::Level) { const TargetInstrInfo *TII = TM.getInstrInfo(); const TargetRegisterInfo *TRI = TM.getRegisterInfo(); - BURegReductionPriorityQueue *PQ = new BURegReductionPriorityQueue(TII, TRI); - + BURegReductionPriorityQueue *PQ = + new BURegReductionPriorityQueue(*IS->MF, true, TII, TRI, 0); ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, false, PQ); PQ->setScheduleDAG(SD); return SD; @@ -1648,8 +1860,8 @@ llvm::createTDRRListDAGScheduler(SelectionDAGISel *IS, CodeGenOpt::Level) { const TargetInstrInfo *TII = TM.getInstrInfo(); const TargetRegisterInfo *TRI = TM.getRegisterInfo(); - TDRegReductionPriorityQueue *PQ = new TDRegReductionPriorityQueue(TII, TRI); - + TDRegReductionPriorityQueue *PQ = + new TDRegReductionPriorityQueue(*IS->MF, false, TII, TRI, 0); ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, false, false, PQ); PQ->setScheduleDAG(SD); return SD; @@ -1661,8 +1873,8 @@ llvm::createSourceListDAGScheduler(SelectionDAGISel *IS, CodeGenOpt::Level) { const TargetInstrInfo *TII = TM.getInstrInfo(); const TargetRegisterInfo *TRI = TM.getRegisterInfo(); - SrcRegReductionPriorityQueue *PQ = new SrcRegReductionPriorityQueue(TII, TRI); - + SrcRegReductionPriorityQueue *PQ = + new SrcRegReductionPriorityQueue(*IS->MF, true, TII, TRI, 0); ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, false, PQ); PQ->setScheduleDAG(SD); return SD; @@ -1673,9 +1885,11 @@ llvm::createHybridListDAGScheduler(SelectionDAGISel *IS, CodeGenOpt::Level) { const TargetMachine &TM = IS->TM; const TargetInstrInfo *TII = TM.getInstrInfo(); const TargetRegisterInfo *TRI = TM.getRegisterInfo(); + const TargetLowering *TLI = &IS->getTargetLowering(); - HybridBURRPriorityQueue *PQ = new HybridBURRPriorityQueue(TII, TRI); - + HybridBURRPriorityQueue *PQ = + new HybridBURRPriorityQueue(*IS->MF, true, TII, TRI, + (RegPressureAware ? TLI : 0)); ScheduleDAGRRList *SD = new ScheduleDAGRRList(*IS->MF, true, true, PQ); PQ->setScheduleDAG(SD); return SD; diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp index dafda50a834..6e6fedeb2e4 100644 --- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -678,9 +678,12 @@ TargetLowering::hasLegalSuperRegRegClasses(const TargetRegisterClass *RC) const{ } /// findRepresentativeClass - Return the largest legal super-reg register class -/// of the specified register class. -const TargetRegisterClass * -TargetLowering::findRepresentativeClass(const TargetRegisterClass *RC) const { +/// of the register class for the specified type and its associated "cost". +std::pair +TargetLowering::findRepresentativeClass(EVT VT) const { + const TargetRegisterClass *RC = RegClassForVT[VT.getSimpleVT().SimpleTy]; + if (!RC) + return std::make_pair(RC, 0); const TargetRegisterClass *BestRC = RC; for (TargetRegisterInfo::regclass_iterator I = RC->superregclasses_begin(), E = RC->superregclasses_end(); I != E; ++I) { @@ -688,10 +691,10 @@ TargetLowering::findRepresentativeClass(const TargetRegisterClass *RC) const { if (RRC->isASubClass() || !isLegalRC(RRC)) continue; if (!hasLegalSuperRegRegClasses(RRC)) - return RRC; + return std::make_pair(RRC, 1); BestRC = RRC; } - return BestRC; + return std::make_pair(BestRC, 1); } /// computeRegisterProperties - Once all of the register classes are added, @@ -820,8 +823,11 @@ void TargetLowering::computeRegisterProperties() { // a group of value types. For example, on i386, i8, i16, and i32 // representative would be GR32; while on x86_64 it's GR64. for (unsigned i = 0; i != MVT::LAST_VALUETYPE; ++i) { - const TargetRegisterClass *RC = RegClassForVT[i]; - RepRegClassForVT[i] = RC ? findRepresentativeClass(RC) : 0; + const TargetRegisterClass* RRC; + uint8_t Cost; + tie(RRC, Cost) = findRepresentativeClass((MVT::SimpleValueType)i); + RepRegClassForVT[i] = RRC; + RepRegClassCostForVT[i] = Cost; } } diff --git a/lib/Target/ARM/ARMISelLowering.cpp b/lib/Target/ARM/ARMISelLowering.cpp index f6d25d8706c..733042266db 100644 --- a/lib/Target/ARM/ARMISelLowering.cpp +++ b/lib/Target/ARM/ARMISelLowering.cpp @@ -550,20 +550,38 @@ ARMTargetLowering::ARMTargetLowering(TargetMachine &TM) benefitFromCodePlacementOpt = true; } -const TargetRegisterClass * -ARMTargetLowering::findRepresentativeClass(const TargetRegisterClass *RC) const{ - switch (RC->getID()) { +std::pair +ARMTargetLowering::findRepresentativeClass(EVT VT) const{ + const TargetRegisterClass *RRC = 0; + uint8_t Cost = 1; + switch (VT.getSimpleVT().SimpleTy) { default: - return RC; - case ARM::tGPRRegClassID: - case ARM::GPRRegClassID: - return ARM::GPRRegisterClass; - case ARM::SPRRegClassID: - case ARM::DPRRegClassID: - return ARM::DPRRegisterClass; - case ARM::QPRRegClassID: - return ARM::QPRRegisterClass; + return TargetLowering::findRepresentativeClass(VT); + // Use SPR as representative register class for all floating point + // and vector types. + case MVT::f32: + RRC = ARM::SPRRegisterClass; + break; + case MVT::f64: case MVT::v8i8: case MVT::v4i16: + case MVT::v2i32: case MVT::v1i64: case MVT::v2f32: + RRC = ARM::SPRRegisterClass; + Cost = 2; + break; + case MVT::v16i8: case MVT::v8i16: case MVT::v4i32: case MVT::v2i64: + case MVT::v4f32: case MVT::v2f64: + RRC = ARM::SPRRegisterClass; + Cost = 4; + break; + case MVT::v4i64: + RRC = ARM::SPRRegisterClass; + Cost = 8; + break; + case MVT::v8i64: + RRC = ARM::SPRRegisterClass; + Cost = 16; + break; } + return std::make_pair(RRC, Cost); } const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const { diff --git a/lib/Target/ARM/ARMISelLowering.h b/lib/Target/ARM/ARMISelLowering.h index ef47003058c..332b7a73be5 100644 --- a/lib/Target/ARM/ARMISelLowering.h +++ b/lib/Target/ARM/ARMISelLowering.h @@ -272,8 +272,8 @@ namespace llvm { virtual bool isFPImmLegal(const APFloat &Imm, EVT VT) const; protected: - const TargetRegisterClass * - findRepresentativeClass(const TargetRegisterClass *RC) const; + std::pair + findRepresentativeClass(EVT VT) const; private: /// Subtarget - Keep a pointer to the ARMSubtarget around so that we can diff --git a/lib/Target/PIC16/PIC16ISelLowering.cpp b/lib/Target/PIC16/PIC16ISelLowering.cpp index 54a6a28992b..527b31d0cc9 100644 --- a/lib/Target/PIC16/PIC16ISelLowering.cpp +++ b/lib/Target/PIC16/PIC16ISelLowering.cpp @@ -312,6 +312,16 @@ PIC16TargetLowering::PIC16TargetLowering(PIC16TargetMachine &TM) computeRegisterProperties(); } +std::pair +PIC16TargetLowering::findRepresentativeClass(EVT VT) const { + switch (VT.getSimpleVT().SimpleTy) { + default: + return TargetLowering::findRepresentativeClass(VT); + case MVT::i16: + return std::make_pair(PIC16::FSR16RegisterClass, 1); + } +} + // getOutFlag - Extract the flag result if the Op has it. static SDValue getOutFlag(SDValue &Op) { // Flag is the last value of the node. diff --git a/lib/Target/PIC16/PIC16ISelLowering.h b/lib/Target/PIC16/PIC16ISelLowering.h index 0a7506cb497..d7bb5c1ce51 100644 --- a/lib/Target/PIC16/PIC16ISelLowering.h +++ b/lib/Target/PIC16/PIC16ISelLowering.h @@ -181,6 +181,9 @@ namespace llvm { // FIXME: The function never seems to be aligned. return 1; } + protected: + std::pair + findRepresentativeClass(EVT VT) const; private: // If the Node is a BUILD_PAIR representing a direct Address, // then this function will return true. -- 2.34.1