Fix gcc -Wpedantic.
[oota-llvm.git] / lib / Target / AArch64 / AArch64PBQPRegAlloc.cpp
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
11 //
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
17
18 #define DEBUG_TYPE "aarch64-pbqp"
19
20 #include "AArch64.h"
21 #include "AArch64RegisterInfo.h"
22
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
25 #include "llvm/CodeGen/MachineBasicBlock.h"
26 #include "llvm/CodeGen/MachineFunction.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/RegAllocPBQP.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
32
33 #define PBQP_BUILDER PBQPBuilderWithCoalescing
34
35 using namespace llvm;
36
37 namespace {
38
39 #ifndef NDEBUG
40 bool isFPReg(unsigned reg) {
41   return AArch64::FPR32RegClass.contains(reg) ||
42          AArch64::FPR64RegClass.contains(reg) ||
43          AArch64::FPR128RegClass.contains(reg);
44 }
45 #endif
46
47 bool isOdd(unsigned reg) {
48   switch (reg) {
49   default:
50     llvm_unreachable("Register is not from the expected class !");
51   case AArch64::S1:
52   case AArch64::S3:
53   case AArch64::S5:
54   case AArch64::S7:
55   case AArch64::S9:
56   case AArch64::S11:
57   case AArch64::S13:
58   case AArch64::S15:
59   case AArch64::S17:
60   case AArch64::S19:
61   case AArch64::S21:
62   case AArch64::S23:
63   case AArch64::S25:
64   case AArch64::S27:
65   case AArch64::S29:
66   case AArch64::S31:
67   case AArch64::D1:
68   case AArch64::D3:
69   case AArch64::D5:
70   case AArch64::D7:
71   case AArch64::D9:
72   case AArch64::D11:
73   case AArch64::D13:
74   case AArch64::D15:
75   case AArch64::D17:
76   case AArch64::D19:
77   case AArch64::D21:
78   case AArch64::D23:
79   case AArch64::D25:
80   case AArch64::D27:
81   case AArch64::D29:
82   case AArch64::D31:
83   case AArch64::Q1:
84   case AArch64::Q3:
85   case AArch64::Q5:
86   case AArch64::Q7:
87   case AArch64::Q9:
88   case AArch64::Q11:
89   case AArch64::Q13:
90   case AArch64::Q15:
91   case AArch64::Q17:
92   case AArch64::Q19:
93   case AArch64::Q21:
94   case AArch64::Q23:
95   case AArch64::Q25:
96   case AArch64::Q27:
97   case AArch64::Q29:
98   case AArch64::Q31:
99     return true;
100   case AArch64::S0:
101   case AArch64::S2:
102   case AArch64::S4:
103   case AArch64::S6:
104   case AArch64::S8:
105   case AArch64::S10:
106   case AArch64::S12:
107   case AArch64::S14:
108   case AArch64::S16:
109   case AArch64::S18:
110   case AArch64::S20:
111   case AArch64::S22:
112   case AArch64::S24:
113   case AArch64::S26:
114   case AArch64::S28:
115   case AArch64::S30:
116   case AArch64::D0:
117   case AArch64::D2:
118   case AArch64::D4:
119   case AArch64::D6:
120   case AArch64::D8:
121   case AArch64::D10:
122   case AArch64::D12:
123   case AArch64::D14:
124   case AArch64::D16:
125   case AArch64::D18:
126   case AArch64::D20:
127   case AArch64::D22:
128   case AArch64::D24:
129   case AArch64::D26:
130   case AArch64::D28:
131   case AArch64::D30:
132   case AArch64::Q0:
133   case AArch64::Q2:
134   case AArch64::Q4:
135   case AArch64::Q6:
136   case AArch64::Q8:
137   case AArch64::Q10:
138   case AArch64::Q12:
139   case AArch64::Q14:
140   case AArch64::Q16:
141   case AArch64::Q18:
142   case AArch64::Q20:
143   case AArch64::Q22:
144   case AArch64::Q24:
145   case AArch64::Q26:
146   case AArch64::Q28:
147   case AArch64::Q30:
148     return false;
149
150   }
151 }
152
153 bool haveSameParity(unsigned reg1, unsigned reg2) {
154   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
155   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
156
157   return isOdd(reg1) == isOdd(reg2);
158 }
159
160 class A57PBQPBuilder : public PBQP_BUILDER {
161 public:
162   A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
163
164   // Build a PBQP instance to represent the register allocation problem for
165   // the given MachineFunction.
166   std::unique_ptr<PBQPRAProblem>
167   build(MachineFunction *MF, const LiveIntervals *LI,
168         const MachineBlockFrequencyInfo *blockInfo,
169         const RegSet &VRegs) override;
170
171 private:
172   const AArch64RegisterInfo *TRI;
173   const LiveIntervals *LIs;
174   SmallSetVector<unsigned, 32> Chains;
175
176   // Return true if reg is a physical register
177   bool isPhysicalReg(unsigned reg) const {
178     return TRI->isPhysicalRegister(reg);
179   }
180
181   // Add the accumulator chaining constraint, inside the chain, i.e. so that
182   // parity(Rd) == parity(Ra).
183   // \return true if a constraint was added
184   bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
185
186   // Add constraints between existing chains
187   void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
188 };
189 } // Anonymous namespace
190
191 bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
192                                              unsigned Ra) {
193   if (Rd == Ra)
194     return false;
195
196   if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
197     DEBUG(dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n');
198     DEBUG(dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n');
199     return false;
200   }
201
202   const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
203   const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
204
205   PBQPRAGraph &g = p->getGraph();
206   PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
207   PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
208   PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
209
210   // The edge does not exist. Create one with the appropriate interference
211   // costs.
212   if (edge == g.invalidEdgeId()) {
213     const LiveInterval &ld = LIs->getInterval(Rd);
214     const LiveInterval &la = LIs->getInterval(Ra);
215     bool livesOverlap = ld.overlaps(la);
216
217     PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
218     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
219       unsigned pRd = (*vRdAllowed)[i];
220       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
221         unsigned pRa = (*vRaAllowed)[j];
222         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
223           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
224         else
225           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
226       }
227     }
228     g.addEdge(node1, node2, std::move(costs));
229     return true;
230   }
231
232   if (g.getEdgeNode1Id(edge) == node2) {
233     std::swap(node1, node2);
234     std::swap(vRdAllowed, vRaAllowed);
235   }
236
237   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
238   PBQP::Matrix costs(g.getEdgeCosts(edge));
239   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
240     unsigned pRd = (*vRdAllowed)[i];
241
242     // Get the maximum cost (excluding unallocatable reg) for same parity
243     // registers
244     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
245     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
246       unsigned pRa = (*vRaAllowed)[j];
247       if (haveSameParity(pRd, pRa))
248         if (costs[i + 1][j + 1] !=
249                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
250             costs[i + 1][j + 1] > sameParityMax)
251           sameParityMax = costs[i + 1][j + 1];
252     }
253
254     // Ensure all registers with a different parity have a higher cost
255     // than sameParityMax
256     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
257       unsigned pRa = (*vRaAllowed)[j];
258       if (!haveSameParity(pRd, pRa))
259         if (sameParityMax > costs[i + 1][j + 1])
260           costs[i + 1][j + 1] = sameParityMax + 1.0;
261     }
262   }
263   g.setEdgeCosts(edge, costs);
264
265   return true;
266 }
267
268 void
269 A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
270                                         unsigned Ra) {
271   // Do some Chain management
272   if (Chains.count(Ra)) {
273     if (Rd != Ra) {
274       DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
275                    << PrintReg(Rd, TRI) << '\n';);
276       Chains.remove(Ra);
277       Chains.insert(Rd);
278     }
279   } else {
280     DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
281                  << '\n';);
282     Chains.insert(Rd);
283   }
284
285   const LiveInterval &ld = LIs->getInterval(Rd);
286   for (auto r : Chains) {
287     // Skip self
288     if (r == Rd)
289       continue;
290
291     const LiveInterval &lr = LIs->getInterval(r);
292     if (ld.overlaps(lr)) {
293       const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
294       const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
295
296       PBQPRAGraph &g = p->getGraph();
297       PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
298       PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
299       PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
300       assert(edge != g.invalidEdgeId() &&
301              "PBQP error ! The edge should exist !");
302
303       DEBUG(dbgs() << "Refining constraint !\n";);
304
305       if (g.getEdgeNode1Id(edge) == node2) {
306         std::swap(node1, node2);
307         std::swap(vRdAllowed, vRrAllowed);
308       }
309
310       // Enforce that cost is higher with all other Chains of the same parity
311       PBQP::Matrix costs(g.getEdgeCosts(edge));
312       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
313         unsigned pRd = (*vRdAllowed)[i];
314
315         // Get the maximum cost (excluding unallocatable reg) for all other
316         // parity registers
317         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
318         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
319           unsigned pRa = (*vRrAllowed)[j];
320           if (!haveSameParity(pRd, pRa))
321             if (costs[i + 1][j + 1] !=
322                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
323                 costs[i + 1][j + 1] > sameParityMax)
324               sameParityMax = costs[i + 1][j + 1];
325         }
326
327         // Ensure all registers with same parity have a higher cost
328         // than sameParityMax
329         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
330           unsigned pRa = (*vRrAllowed)[j];
331           if (haveSameParity(pRd, pRa))
332             if (sameParityMax > costs[i + 1][j + 1])
333               costs[i + 1][j + 1] = sameParityMax + 1.0;
334         }
335       }
336       g.setEdgeCosts(edge, costs);
337     }
338   }
339 }
340
341 std::unique_ptr<PBQPRAProblem>
342 A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
343                       const MachineBlockFrequencyInfo *blockInfo,
344                       const RegSet &VRegs) {
345   std::unique_ptr<PBQPRAProblem> p =
346       PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
347
348   TRI = static_cast<const AArch64RegisterInfo *>(
349       MF->getTarget().getSubtargetImpl()->getRegisterInfo());
350   LIs = LI;
351
352   DEBUG(MF->dump(););
353
354   for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
355        mbbItr != mbbEnd; ++mbbItr) {
356     const MachineBasicBlock *MBB = &*mbbItr;
357     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
358
359     for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
360                                            miEnd = MBB->end();
361          miItr != miEnd; ++miItr) {
362       const MachineInstr *MI = &*miItr;
363       switch (MI->getOpcode()) {
364       case AArch64::FMSUBSrrr:
365       case AArch64::FMADDSrrr:
366       case AArch64::FNMSUBSrrr:
367       case AArch64::FNMADDSrrr:
368       case AArch64::FMSUBDrrr:
369       case AArch64::FMADDDrrr:
370       case AArch64::FNMSUBDrrr:
371       case AArch64::FNMADDDrrr: {
372         unsigned Rd = MI->getOperand(0).getReg();
373         unsigned Ra = MI->getOperand(3).getReg();
374
375         if (addIntraChainConstraint(p.get(), Rd, Ra))
376           addInterChainConstraint(p.get(), Rd, Ra);
377         break;
378       }
379
380       case AArch64::FMLAv2f32:
381       case AArch64::FMLSv2f32: {
382         unsigned Rd = MI->getOperand(0).getReg();
383         addInterChainConstraint(p.get(), Rd, Rd);
384         break;
385       }
386
387       default:
388         // Forget Chains which have been killed
389         for (auto r : Chains) {
390           SmallVector<unsigned, 8> toDel;
391           if (MI->killsRegister(r)) {
392             DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
393                   MI->print(dbgs()););
394             toDel.push_back(r);
395           }
396
397           while (!toDel.empty()) {
398             Chains.remove(toDel.back());
399             toDel.pop_back();
400           }
401         }
402       }
403     }
404   }
405
406   return p;
407 }
408
409 // Factory function used by AArch64TargetMachine to add the pass to the
410 // passmanager.
411 FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
412   std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
413   return createPBQPRegisterAllocator(std::move(builder), nullptr);
414 }