1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
18 #define DEBUG_TYPE "aarch64-pbqp"
21 #include "AArch64PBQPRegAlloc.h"
22 #include "AArch64RegisterInfo.h"
23 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
24 #include "llvm/CodeGen/MachineBasicBlock.h"
25 #include "llvm/CodeGen/MachineFunction.h"
26 #include "llvm/CodeGen/MachineRegisterInfo.h"
27 #include "llvm/CodeGen/RegAllocPBQP.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
// Returns true when `reg` is contained in one of the AArch64 FPR32, FPR64 or
// FPR128 register classes, and false otherwise.
37 bool isFPReg(unsigned reg) {
38 return AArch64::FPR32RegClass.contains(reg) ||
39 AArch64::FPR64RegClass.contains(reg) ||
40 AArch64::FPR128RegClass.contains(reg);
// Parity predicate for a register; used by haveSameParity() to compare two
// registers.
// NOTE(review): most of the body is not visible in this listing; presumably it
// reports whether `reg` has an odd index within its register class -- confirm.
44 bool isOdd(unsigned reg) {
// Fallback path: per its message, this fires for a register that is not from
// one of the expected classes.
47 llvm_unreachable("Register is not from the expected class !");
// Returns true when `reg1` and `reg2` yield the same isOdd() result, i.e. both
// are odd or both are even. Both arguments are asserted to be FP registers
// (see isFPReg).
150 bool haveSameParity(unsigned reg1, unsigned reg2) {
151 assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152 assert(isFPReg(reg2) && "Expecting an FP register for reg2");
// Parity equality is all that matters here, not the parity itself.
154 return isOdd(reg1) == isOdd(reg2);
// Biases the PBQP edge between the nodes of Rd and Ra so that assigning both
// registers the same parity is cheaper than assigning them different parities.
//
// If no edge exists between the two nodes, a new cost matrix is built (infinite
// cost for conflicting assignments, 0 for same parity, 1 for different parity).
// If an edge already exists, its costs are refined so that every
// different-parity entry that is below the largest finite same-parity cost of
// its row is raised above it.
//
// G      - the PBQP register allocation graph to update.
// Rd, Ra - the two registers to tie together; both are used below as virtual
//          registers to look up PBQP nodes and live intervals.
//
// NOTE(review): the tail of the signature and the return statements are not
// visible in this listing. The caller in apply() only adds the inter-chain
// constraint when this function returns true -- confirm the exact contract.
159 bool A57PBQPConstraints::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
164 const TargetRegisterInfo &TRI =
165 *G.getMetadata().MF.getTarget().getSubtargetImpl()->getRegisterInfo();
166 LiveIntervals &LIs = G.getMetadata().LIS;
// Physical registers are special-cased: only debug output is visible for this
// branch.
// NOTE(review): presumably the function bails out here -- confirm.
168 if (TRI.isPhysicalRegister(Rd) || TRI.isPhysicalRegister(Ra)) {
169 DEBUG(dbgs() << "Rd is a physical reg:" << TRI.isPhysicalRegister(Rd)
171 DEBUG(dbgs() << "Ra is a physical reg:" << TRI.isPhysicalRegister(Ra)
// Look up the PBQP node of each virtual register.
176 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
177 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
// Option-index -> register tables of each node; entry i is the register that
// corresponds to row/column i + 1 of the edge cost matrix.
179 const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRdAllowed =
180 &G.getNodeMetadata(node1).getOptionRegs();
181 const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRaAllowed =
182 &G.getNodeMetadata(node2).getOptionRegs();
184 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
186 // The edge does not exist. Create one with the appropriate interference
188 if (edge == G.invalidEdgeId()) {
189 const LiveInterval &ld = LIs.getInterval(Rd);
190 const LiveInterval &la = LIs.getInterval(Ra);
191 bool livesOverlap = ld.overlaps(la);
// One extra row and column: index 0 is never written by the loops below and
// keeps its initial cost of 0.
// NOTE(review): presumably option 0 is the spill option -- confirm.
193 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
194 vRaAllowed->size() + 1, 0);
195 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
196 unsigned pRd = (*vRdAllowed)[i];
197 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
198 unsigned pRa = (*vRaAllowed)[j];
// Overlapping live ranges must not be given overlapping physical registers:
// such a pair gets an infinite cost.
199 if (livesOverlap && TRI.regsOverlap(pRd, pRa))
200 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
// Parity bias: same parity costs 0, different parity costs 1.
202 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
205 G.addEdge(node1, node2, std::move(costs));
// The edge may record its endpoints in the opposite order. Swap the local view
// so that rows (index i) correspond to the edge's first node.
209 if (G.getEdgeNode1Id(edge) == node2) {
210 std::swap(node1, node2);
211 std::swap(vRdAllowed, vRaAllowed);
214 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
// Work on a copy of the current edge costs; it is written back at the end.
215 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
216 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
217 unsigned pRd = (*vRdAllowed)[i];
219 // Get the maximum cost (excluding unallocatable reg) for same parity
// NOTE(review): for a floating-point type, numeric_limits<>::min() is the
// smallest positive normalized value, not the most negative one. PBQPNum's
// definition is not visible here -- confirm this seed is intended.
221 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
222 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
223 unsigned pRa = (*vRaAllowed)[j];
// Infinite entries are ignored when computing the maximum.
224 if (haveSameParity(pRd, pRa))
225 if (costs[i + 1][j + 1] !=
226 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
227 costs[i + 1][j + 1] > sameParityMax)
228 sameParityMax = costs[i + 1][j + 1];
231 // Ensure all registers with a different parity have a higher cost
232 // than sameParityMax
233 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
234 unsigned pRa = (*vRaAllowed)[j];
// Only entries strictly below sameParityMax are raised; entries that are
// already at or above it (including infinite ones) are left unchanged.
235 if (!haveSameParity(pRd, pRa))
236 if (sameParityMax > costs[i + 1][j + 1])
237 costs[i + 1][j + 1] = sameParityMax + 1.0;
// Store the refined matrix back on the edge.
240 G.setEdgeCosts(edge, std::move(costs));
// Discourages Rd from sharing a parity with the other registers tracked in
// Chains.
//
// After some chain bookkeeping on Ra/Rd, every register r in Chains whose live
// interval overlaps Rd's has its existing PBQP edge with Rd refined: each
// same-parity entry that is below the largest finite different-parity cost of
// its row is raised above it.
//
// G      - the PBQP register allocation graph to update.
// Rd, Ra - the two registers involved; Ra is looked up in Chains and Rd is the
//          register whose edges are refined.
//
// NOTE(review): the tail of the signature and the statements that actually
// update Chains are not visible in this listing.
245 void A57PBQPConstraints::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
247 const TargetRegisterInfo &TRI =
248 *G.getMetadata().MF.getTarget().getSubtargetImpl()->getRegisterInfo();
250 LiveIntervals &LIs = G.getMetadata().LIS;
252 // Do some Chain management
// Per the debug messages: when Ra is already tracked in Chains the chain is
// moved from Ra to Rd; otherwise a new chain is created for Rd.
253 if (Chains.count(Ra)) {
255 DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, &TRI) << " to "
256 << PrintReg(Rd, &TRI) << '\n';);
261 DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, &TRI)
266 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
268 const LiveInterval &ld = LIs.getInterval(Rd);
// Compare Rd against every register currently tracked in Chains.
269 for (auto r : Chains) {
274 const LiveInterval &lr = LIs.getInterval(r);
// Only chains that are live at the same time as Rd are constrained.
275 if (ld.overlaps(lr)) {
276 const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRdAllowed =
277 &G.getNodeMetadata(node1).getOptionRegs();
279 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
280 const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRrAllowed =
281 &G.getNodeMetadata(node2).getOptionRegs();
// The edge is only refined here, never created: it is asserted to exist.
283 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
284 assert(edge != G.invalidEdgeId() &&
285 "PBQP error ! The edge should exist !");
287 DEBUG(dbgs() << "Refining constraint !\n";);
// The edge may record its endpoints in the opposite order. Swap the local view
// so that rows (index i) correspond to the edge's first node.
// NOTE(review): node1 is declared before the loop over Chains, so this swap
// persists into later iterations, which would then start from r's node instead
// of Rd's -- unless it is reset in lines not visible here. Confirm.
289 if (G.getEdgeNode1Id(edge) == node2) {
290 std::swap(node1, node2);
291 std::swap(vRdAllowed, vRrAllowed);
294 // Enforce that cost is higher with all other Chains of the same parity
// Work on a copy of the current edge costs; it is written back at the end.
295 PBQP::Matrix costs(G.getEdgeCosts(edge));
296 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
297 unsigned pRd = (*vRdAllowed)[i];
299 // Get the maximum cost (excluding unallocatable reg) for all other
// Despite its name, sameParityMax holds the largest finite cost among the
// *different*-parity pairs of this row (note the negated test below).
// NOTE(review): for a floating-point type, numeric_limits<>::min() is the
// smallest positive normalized value, not the most negative one. PBQPNum's
// definition is not visible here -- confirm this seed is intended.
301 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
302 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
303 unsigned pRa = (*vRrAllowed)[j];
304 if (!haveSameParity(pRd, pRa))
305 if (costs[i + 1][j + 1] !=
306 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
307 costs[i + 1][j + 1] > sameParityMax)
308 sameParityMax = costs[i + 1][j + 1];
311 // Ensure all registers with same parity have a higher cost
312 // than sameParityMax
313 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
314 unsigned pRa = (*vRrAllowed)[j];
// Only entries strictly below the threshold are raised; entries already at or
// above it (including infinite ones) are left unchanged.
315 if (haveSameParity(pRd, pRa))
316 if (sameParityMax > costs[i + 1][j + 1])
317 costs[i + 1][j + 1] = sameParityMax + 1.0;
// Store the refined matrix back on the edge.
320 G.setEdgeCosts(edge, std::move(costs));
325 void A57PBQPConstraints::apply(PBQPRAGraph &G) {
326 MachineFunction &MF = G.getMetadata().MF;
328 const TargetRegisterInfo &TRI =
329 *MF.getTarget().getSubtargetImpl()->getRegisterInfo();
333 for (MachineFunction::const_iterator mbbItr = MF.begin(), mbbEnd = MF.end();
334 mbbItr != mbbEnd; ++mbbItr) {
335 const MachineBasicBlock *MBB = &*mbbItr;
336 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
338 for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
340 miItr != miEnd; ++miItr) {
341 const MachineInstr *MI = &*miItr;
342 switch (MI->getOpcode()) {
343 case AArch64::FMSUBSrrr:
344 case AArch64::FMADDSrrr:
345 case AArch64::FNMSUBSrrr:
346 case AArch64::FNMADDSrrr:
347 case AArch64::FMSUBDrrr:
348 case AArch64::FMADDDrrr:
349 case AArch64::FNMSUBDrrr:
350 case AArch64::FNMADDDrrr: {
351 unsigned Rd = MI->getOperand(0).getReg();
352 unsigned Ra = MI->getOperand(3).getReg();
354 if (addIntraChainConstraint(G, Rd, Ra))
355 addInterChainConstraint(G, Rd, Ra);
359 case AArch64::FMLAv2f32:
360 case AArch64::FMLSv2f32: {
361 unsigned Rd = MI->getOperand(0).getReg();
362 addInterChainConstraint(G, Rd, Rd);
367 // Forget Chains which have been killed
368 for (auto r : Chains) {
369 SmallVector<unsigned, 8> toDel;
370 if (MI->killsRegister(r)) {
371 DEBUG(dbgs() << "Killing chain " << PrintReg(r, &TRI) << " at ";
376 while (!toDel.empty()) {
377 Chains.remove(toDel.back());