#define DEBUG_TYPE "aarch64-pbqp"
#include "AArch64.h"
+#include "AArch64PBQPRegAlloc.h"
#include "AArch64RegisterInfo.h"
-
-#include "llvm/ADT/SetVector.h"
#include "llvm/CodeGen/LiveIntervalAnalysis.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
-#define PBQP_BUILDER PBQPBuilderWithCoalescing
-//#define PBQP_BUILDER PBQPBuilder
-
using namespace llvm;
namespace {
+#ifndef NDEBUG
bool isFPReg(unsigned reg) {
return AArch64::FPR32RegClass.contains(reg) ||
AArch64::FPR64RegClass.contains(reg) ||
AArch64::FPR128RegClass.contains(reg);
-};
+}
+#endif
bool isOdd(unsigned reg) {
switch (reg) {
return isOdd(reg1) == isOdd(reg2);
}
-class A57PBQPBuilder : public PBQP_BUILDER {
-public:
- A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
-
- // Build a PBQP instance to represent the register allocation problem for
- // the given MachineFunction.
- std::unique_ptr<PBQPRAProblem>
- build(MachineFunction *MF, const LiveIntervals *LI,
- const MachineBlockFrequencyInfo *blockInfo,
- const RegSet &VRegs) override;
-
-private:
- const AArch64RegisterInfo *TRI;
- const LiveIntervals *LIs;
- SmallSetVector<unsigned, 32> Chains;
-
- // Return true if reg is a physical register
- bool isPhysicalReg(unsigned reg) const {
- return TRI->isPhysicalRegister(reg);
- }
-
- // Add the accumulator chaining constraint, inside the chain, i.e. so that
- // parity(Rd) == parity(Ra).
- // \return true if a constraint was added
- bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
-
- // Add constraints between existing chains
- void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
-};
-} // Anonymous namespace
+}
-bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
- unsigned Ra) {
+bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
+ unsigned Ra) {
if (Rd == Ra)
return false;
- if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
- dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n';
- dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n';
+ LiveIntervals &LIs = G.getMetadata().LIS;
+
+ if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
+ DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
+ << '\n');
+ DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
+ << '\n');
return false;
}
- const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
- const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
+ PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
+ PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
+
+ const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
+ &G.getNodeMetadata(node1).getAllowedRegs();
+ const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
+ &G.getNodeMetadata(node2).getAllowedRegs();
- PBQPRAGraph &g = p->getGraph();
- PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
- PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
- PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
+ PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
// The edge does not exist. Create one with the appropriate interference
// costs.
- if (edge == g.invalidEdgeId()) {
- const LiveInterval &ld = LIs->getInterval(Rd);
- const LiveInterval &la = LIs->getInterval(Ra);
+ if (edge == G.invalidEdgeId()) {
+ const LiveInterval &ld = LIs.getInterval(Rd);
+ const LiveInterval &la = LIs.getInterval(Ra);
bool livesOverlap = ld.overlaps(la);
- PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
- for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+ PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
+ vRaAllowed->size() + 1, 0);
+ for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
unsigned pRd = (*vRdAllowed)[i];
- for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+ for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
unsigned pRa = (*vRaAllowed)[j];
if (livesOverlap && TRI->regsOverlap(pRd, pRa))
costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
}
}
- g.addEdge(node1, node2, std::move(costs));
+ G.addEdge(node1, node2, std::move(costs));
return true;
}
- if (g.getEdgeNode1Id(edge) == node2) {
+ if (G.getEdgeNode1Id(edge) == node2) {
std::swap(node1, node2);
std::swap(vRdAllowed, vRaAllowed);
}
// Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
- PBQP::Matrix costs(g.getEdgeCosts(edge));
- for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+ PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
+ for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
unsigned pRd = (*vRdAllowed)[i];
// Get the maximum cost (excluding unallocatable reg) for same parity
// registers
PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
- for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+ for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
unsigned pRa = (*vRaAllowed)[j];
if (haveSameParity(pRd, pRa))
if (costs[i + 1][j + 1] !=
// Ensure all registers with a different parity have a higher cost
// than sameParityMax
- for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+ for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
unsigned pRa = (*vRaAllowed)[j];
if (!haveSameParity(pRd, pRa))
if (sameParityMax > costs[i + 1][j + 1])
costs[i + 1][j + 1] = sameParityMax + 1.0;
}
}
- g.setEdgeCosts(edge, costs);
+ G.updateEdgeCosts(edge, std::move(costs));
return true;
}
-void
-A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
- unsigned Ra) {
+void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
+ unsigned Ra) {
+ LiveIntervals &LIs = G.getMetadata().LIS;
+
// Do some Chain management
if (Chains.count(Ra)) {
if (Rd != Ra) {
Chains.insert(Rd);
}
- const LiveInterval &ld = LIs->getInterval(Rd);
+ PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
+
+ const LiveInterval &ld = LIs.getInterval(Rd);
for (auto r : Chains) {
// Skip self
if (r == Rd)
continue;
- const LiveInterval &lr = LIs->getInterval(r);
+ const LiveInterval &lr = LIs.getInterval(r);
if (ld.overlaps(lr)) {
- const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
- const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
-
- PBQPRAGraph &g = p->getGraph();
- PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
- PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
- PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
- assert(edge != g.invalidEdgeId() &&
+ const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
+ &G.getNodeMetadata(node1).getAllowedRegs();
+
+ PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
+ const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
+ &G.getNodeMetadata(node2).getAllowedRegs();
+
+ PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
+ assert(edge != G.invalidEdgeId() &&
"PBQP error ! The edge should exist !");
DEBUG(dbgs() << "Refining constraint !\n";);
- if (g.getEdgeNode1Id(edge) == node2) {
+ if (G.getEdgeNode1Id(edge) == node2) {
std::swap(node1, node2);
std::swap(vRdAllowed, vRrAllowed);
}
// Enforce that cost is higher with all other Chains of the same parity
- PBQP::Matrix costs(g.getEdgeCosts(edge));
- for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+ PBQP::Matrix costs(G.getEdgeCosts(edge));
+ for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
unsigned pRd = (*vRdAllowed)[i];
// Get the maximum cost (excluding unallocatable reg) for all other
// parity registers
PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
- for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
+ for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
unsigned pRa = (*vRrAllowed)[j];
if (!haveSameParity(pRd, pRa))
if (costs[i + 1][j + 1] !=
// Ensure all registers with same parity have a higher cost
// than sameParityMax
- for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
+ for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
unsigned pRa = (*vRrAllowed)[j];
if (haveSameParity(pRd, pRa))
if (sameParityMax > costs[i + 1][j + 1])
costs[i + 1][j + 1] = sameParityMax + 1.0;
}
}
- g.setEdgeCosts(edge, costs);
+ G.updateEdgeCosts(edge, std::move(costs));
}
}
}
-std::unique_ptr<PBQPRAProblem>
-A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
- const MachineBlockFrequencyInfo *blockInfo,
- const RegSet &VRegs) {
- std::unique_ptr<PBQPRAProblem> p =
- PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
+/// Report whether the live interval of \p reg counts as expired at the slot
+/// index assigned to \p MI, i.e. the answer LiveInterval::expiredAt gives for
+/// that instruction's index.
+static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
+ const MachineInstr &MI) {
+ const LiveInterval &regInterval = LIs.getInterval(reg);
+ return regInterval.expiredAt(LIs.getInstructionIndex(&MI));
+}
- TRI = static_cast<const AArch64RegisterInfo *>(
- MF->getTarget().getSubtargetImpl()->getRegisterInfo());
- LIs = LI;
+/// Add the accumulator-chaining constraints to the PBQP graph \p G.
+///
+/// Every instruction of the MachineFunction stored in the graph metadata is
+/// visited, basic block by basic block. Chains is cleared at the start of each
+/// block. Before an instruction is examined, chains whose register is reported
+/// as expired by regJustKilledBefore() are dropped. For the scalar
+/// multiply-accumulate opcodes listed in the switch, an intra-chain constraint
+/// is added between operand 0 (Rd) and operand 3 (Ra), followed by the
+/// inter-chain constraints when one was added. For the vector opcodes listed,
+/// only the inter-chain constraints on Rd are added.
+void A57ChainingConstraint::apply(PBQPRAGraph &G) {
+ const MachineFunction &MF = G.getMetadata().MF;
+ LiveIntervals &LIs = G.getMetadata().LIS;
- DEBUG(MF->dump(););
+ TRI = MF.getSubtarget().getRegisterInfo();
+ DEBUG(MF.dump());
- for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
- mbbItr != mbbEnd; ++mbbItr) {
- const MachineBasicBlock *MBB = &*mbbItr;
+ for (const auto &MBB: MF) {
Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
- for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
- miEnd = MBB->end();
- miItr != miEnd; ++miItr) {
- const MachineInstr *MI = &*miItr;
- switch (MI->getOpcode()) {
+ for (const auto &MI: MBB) {
+
+ // Forget Chains which have expired. The expired registers are collected
+ // first and only removed once the scan is over: removing from Chains
+ // while iterating over it would invalidate the iteration.
+ SmallVector<unsigned, 8> toDel;
+ for (auto r : Chains) {
+ if (regJustKilledBefore(LIs, r, MI)) {
+ DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
+ MI.print(dbgs()););
+ toDel.push_back(r);
+ }
+ }
+
+ while (!toDel.empty()) {
+ Chains.remove(toDel.back());
+ toDel.pop_back();
+ }
+
+ switch (MI.getOpcode()) {
case AArch64::FMSUBSrrr:
case AArch64::FMADDSrrr:
case AArch64::FNMSUBSrrr:
case AArch64::FMADDDrrr:
case AArch64::FNMSUBDrrr:
case AArch64::FNMADDDrrr: {
- unsigned Rd = MI->getOperand(0).getReg();
- unsigned Ra = MI->getOperand(3).getReg();
+ unsigned Rd = MI.getOperand(0).getReg();
+ unsigned Ra = MI.getOperand(3).getReg();
- if (addIntraChainConstraint(p.get(), Rd, Ra))
- addInterChainConstraint(p.get(), Rd, Ra);
+ if (addIntraChainConstraint(G, Rd, Ra))
+ addInterChainConstraint(G, Rd, Ra);
break;
}
case AArch64::FMLAv2f32:
case AArch64::FMLSv2f32: {
- unsigned Rd = MI->getOperand(0).getReg();
- addInterChainConstraint(p.get(), Rd, Rd);
+ unsigned Rd = MI.getOperand(0).getReg();
+ addInterChainConstraint(G, Rd, Rd);
break;
}
default:
- // Forget Chains which have been killed
- for (auto r : Chains) {
- SmallVector<unsigned, 8> toDel;
- if (MI->killsRegister(r)) {
- DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
- MI->print(dbgs()););
- toDel.push_back(r);
- }
-
- while (!toDel.empty()) {
- Chains.remove(toDel.back());
- toDel.pop_back();
- }
- }
+ break;
}
}
}
-
- return p;
-}
-
-// Factory function used by AArch64TargetMachine to add the pass to the
-// passmanager.
-FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
- std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
- return createPBQPRegisterAllocator(std::move(builder), nullptr);
}