1 //===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Heuristic PBQP solver. This solver is able to perform optimal reductions for
11 // nodes of degree 0, 1 or 2. For nodes of degree >2 a pluggable heuristic is
12 // used to select a node for reduction.
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17 #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
26 /// \brief Heuristic PBQP solver implementation.
28 /// This class should usually be created (and destroyed) indirectly via a call
29 /// to HeuristicSolver<HImpl>::solve(Graph&).
30 /// See the comments for HeuristicSolver.
32 /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules,
33 /// backpropagation phase, and maintains the internal copy of the graph on
34 /// which the reduction is carried out (the original being kept to facilitate
/// backpropagation).
///
/// \tparam HImpl Heuristic implementation. It supplies the NodeData and
///         EdgeData types that the solver stores alongside its own per-node
///         and per-edge records.
36 template <typename HImpl>
37 class HeuristicSolverImpl {
// Per-node and per-edge data types defined by the heuristic.
40 typedef typename HImpl::NodeData HeuristicNodeData;
41 typedef typename HImpl::EdgeData HeuristicEdgeData;
// List of edge ids. Each solver node record keeps one such list.
43 typedef std::list<Graph::EdgeId> SolverEdges;
47 /// \brief Iterator type for edges in the solver graph.
48 typedef SolverEdges::iterator SolverEdgeItr;
// Solver-side record for one graph node: the heuristic's data, a degree
// counter and the list of edge ids attached to the node in the solver graph.

// Start with a solver degree of zero and an empty edge list.
54 NodeData() : solverDegree(0) {}
// Access the heuristic's data for this node.
56 HeuristicNodeData& getHeuristicData() { return hData; }
// Append eId to this node's edge list. The returned iterator points at the
// new entry, so the caller can hand it back to removeSolverEdge() later.
58 SolverEdgeItr addSolverEdge(Graph::EdgeId eId) {
60 return solverEdges.insert(solverEdges.end(), eId);
// Erase the entry at seItr from this node's edge list.
63 void removeSolverEdge(SolverEdgeItr seItr) {
65 solverEdges.erase(seItr);
// Bounds for iterating over this node's edge list.
68 SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); }
69 SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); }
// Current value of the degree counter.
70 unsigned getSolverDegree() const { return solverDegree; }
// NOTE(review): presumably empties solverEdges and resets solverDegree -
// confirm.
71 void clearSolverEdges() {
// Heuristic-owned data for this node.
77 HeuristicNodeData hData;
// NOTE(review): expected to track the number of entries in solverEdges -
// confirm.
78 unsigned solverDegree;
// Ids of the edges attached to this node in the solver graph.
79 SolverEdges solverEdges;
// Solver-side record for one graph edge: the heuristic's data plus the
// positions of this edge inside the edge lists of its two end nodes.

// Access the heuristic's data for this edge.
84 HeuristicEdgeData& getHeuristicData() { return hData; }
// Remember where this edge sits in the edge list of the edge's node 1.
86 void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) {
87 this->n1SolverEdgeItr = n1SolverEdgeItr;
// Position of this edge in the edge list of the edge's node 1.
90 SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; }
// Remember where this edge sits in the edge list of the edge's node 2.
92 void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){
93 this->n2SolverEdgeItr = n2SolverEdgeItr;
// Position of this edge in the edge list of the edge's node 2.
96 SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; }
// Heuristic-owned data for this edge.
100 HeuristicEdgeData hData;
// Cached list positions. HeuristicSolverImpl::addSolverEdge() sets them and
// HeuristicSolverImpl::removeSolverEdge() uses them to erase the entries.
101 SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr;
// Reduction stack: node ids are appended by pushToStack() and read from the
// back by backpropagate().
107 std::vector<Graph::NodeId> stack;
// Storage for the NodeData records. The graph is given a pointer to each
// record through setNodeData(); std::list does not relocate its elements,
// so those pointers stay valid as more records are appended.
109 typedef std::list<NodeData> NodeDataList;
110 NodeDataList nodeDataList;
// Storage for the EdgeData records, registered with the graph through
// setEdgeData() in the same way.
112 typedef std::list<EdgeData> EdgeDataList;
113 EdgeDataList edgeDataList;
117 /// \brief Construct a heuristic solver implementation to solve the given
119 /// @param g The graph representing the problem instance to be solved.
120 HeuristicSolverImpl(Graph &g) : g(g), h(*this) {}
122 /// \brief Get the graph being solved by this solver.
123 /// @return The graph representing the problem instance being solved by this
///         solver.
125 Graph& getGraph() { return g; }
127 /// \brief Get the heuristic data attached to the given node.
128 /// @param nId Node id.
129 /// @return The heuristic data attached to the given node.
///
/// The heuristic data is held inside the solver's own record for the node,
/// which is looked up with getSolverNodeData().
130 HeuristicNodeData& getHeuristicNodeData(Graph::NodeId nId) {
131 return getSolverNodeData(nId).getHeuristicData();
134 /// \brief Get the heuristic data attached to the given edge.
135 /// @param eId Edge id.
136 /// @return The heuristic data attached to the given edge.
///
/// The heuristic data is held inside the solver's own record for the edge,
/// which is looked up with getSolverEdgeData().
137 HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeId eId) {
138 return getSolverEdgeData(eId).getHeuristicData();
141 /// \brief Begin iterator for the set of edges adjacent to the given node in
142 /// the solver graph.
143 /// @param nId Node id.
144 /// @return Begin iterator for the set of edges adjacent to the given node
145 /// in the solver graph.
///
/// Dereferencing the iterator yields a Graph::EdgeId.
146 SolverEdgeItr solverEdgesBegin(Graph::NodeId nId) {
147 return getSolverNodeData(nId).solverEdgesBegin();
150 /// \brief End iterator for the set of edges adjacent to the given node in
151 /// the solver graph.
152 /// @param nId Node id.
153 /// @return End iterator for the set of edges adjacent to the given node in
154 /// the solver graph.
///
/// Pairs with solverEdgesBegin(nId) to bound a loop over the node's solver
/// edges.
155 SolverEdgeItr solverEdgesEnd(Graph::NodeId nId) {
156 return getSolverNodeData(nId).solverEdgesEnd();
159 /// \brief Remove an edge from the solver graph.
160 /// @param eId Edge id for edge to be removed.
162 /// Does <i>not</i> notify the heuristic of the removal. That should be
163 /// done manually if necessary.
///
/// The edge is erased from the solver edge lists of both of its end nodes,
/// using the list positions cached in the edge's EdgeData.
164 void removeSolverEdge(Graph::EdgeId eId) {
165 EdgeData &eData = getSolverEdgeData(eId);
166 NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eId)),
167 &n2Data = getSolverNodeData(g.getEdgeNode2(eId));
// Each cached iterator belongs to the list of the matching end node.
169 n1Data.removeSolverEdge(eData.getN1SolverEdgeItr());
170 n2Data.removeSolverEdge(eData.getN2SolverEdgeItr());
173 /// \brief Compute a solution to the PBQP problem instance with which this
174 /// heuristic solver was constructed.
175 /// @return A solution to the PBQP problem.
177 /// Performs the full PBQP heuristic solver algorithm, including setup,
178 /// calls to the heuristic (which will call back to the reduction rules in
179 /// this class), and cleanup.
180 Solution computeSolution() {
190 /// \brief Add to the end of the stack.
191 /// @param nId Node id to add to the reduction stack.
///
/// The node's solver edges are cleared before the push. setSolution() later
/// adds edges to that same list, and computeSolution(nId) reads them back as
/// the node's solved edges.
192 void pushToStack(Graph::NodeId nId) {
193 getSolverNodeData(nId).clearSolverEdges();
194 stack.push_back(nId);
197 /// \brief Returns the solver degree of the given node.
198 /// @param nId Node id for which degree is requested.
199 /// @return Node degree in the <i>solver</i> graph (not the original graph).
///
/// The value is read from the degree counter in the node's solver record.
200 unsigned getSolverDegree(Graph::NodeId nId) {
201 return getSolverNodeData(nId).getSolverDegree();
204 /// \brief Set the solution of the given node.
205 /// @param nId Node id to set solution for.
206 /// @param selection Selection for node.
///
/// After the selection is recorded, every edge adjacent to nId in the graph
/// is added to the solver edge list of the node at its other end. That
/// neighbour can then take this selection into account when
/// computeSolution(nId) is run for it.
207 void setSolution(const Graph::NodeId &nId, unsigned selection) {
208 s.setSelection(nId, selection);
// Register each adjacent edge with the node on its far side.
210 for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nId),
211 aeEnd = g.adjEdgesEnd(nId);
212 aeItr != aeEnd; ++aeItr) {
213 Graph::EdgeId eId(*aeItr);
214 Graph::NodeId anId(g.getEdgeOtherNode(eId, nId));
215 getSolverNodeData(anId).addSolverEdge(eId);
219 /// \brief Apply rule R0.
220 /// @param nId Node id for node to apply R0 to.
222 /// Node will be automatically pushed to the solver stack.
///
/// Asserts that the node's solver degree is zero.
223 void applyR0(Graph::NodeId nId) {
224 assert(getSolverNodeData(nId).getSolverDegree() == 0 &&
225 "R0 applied to node with degree != 0.");
227 // Nothing to do. Just push the node onto the reduction stack.
233 /// \brief Apply rule R1.
234 /// @param xnId Node id for node to apply R1 to.
236 /// Node will be automatically pushed to the solver stack.
///
/// x is the node being reduced and y is the node at the other end of x's
/// single solver edge. For each selection of y, the cheapest sum of the edge
/// cost and x's node cost over all selections of x is computed.
/// NOTE(review): that minimum is presumably added to y's node cost for the
/// selection - confirm.
/// Afterwards the heuristic is told the edge is gone and the edge is removed
/// from the solver graph.
237 void applyR1(Graph::NodeId xnId) {
238 NodeData &nd = getSolverNodeData(xnId);
239 assert(nd.getSolverDegree() == 1 &&
240 "R1 applied to node with degree != 1.");
// The only solver edge of x.
242 Graph::EdgeId eId = *nd.solverEdgesBegin();
244 const Matrix &eCosts = g.getEdgeCosts(eId);
245 const Vector &xCosts = g.getNodeCosts(xnId);
247 // Duplicate a little to avoid transposing matrices.
// Case 1: x is node 1 of the edge, so rows are indexed by x's selection (i)
// and columns by y's selection (j).
248 if (xnId == g.getEdgeNode1(eId)) {
249 Graph::NodeId ynId = g.getEdgeNode2(eId);
250 Vector &yCosts = g.getNodeCosts(ynId);
251 for (unsigned j = 0; j < yCosts.getLength(); ++j) {
// Seed the minimum with x's selection 0, then scan the remaining ones.
252 PBQPNum min = eCosts[0][j] + xCosts[0];
253 for (unsigned i = 1; i < xCosts.getLength(); ++i) {
254 PBQPNum c = eCosts[i][j] + xCosts[i];
260 h.handleRemoveEdge(eId, ynId);
// Case 2: y is node 1 of the edge, so rows are indexed by y's selection (i)
// and columns by x's selection (j).
262 Graph::NodeId ynId = g.getEdgeNode1(eId);
263 Vector &yCosts = g.getNodeCosts(ynId);
264 for (unsigned i = 0; i < yCosts.getLength(); ++i) {
265 PBQPNum min = eCosts[i][0] + xCosts[0];
266 for (unsigned j = 1; j < xCosts.getLength(); ++j) {
267 PBQPNum c = eCosts[i][j] + xCosts[j];
273 h.handleRemoveEdge(eId, ynId);
// The heuristic has been notified above; now drop the edge from the solver
// graph, which must leave x with no solver edges.
275 removeSolverEdge(eId);
276 assert(nd.getSolverDegree() == 0 &&
277 "Degree 1 with edge removed should be 0.");
282 /// \brief Apply rule R2.
283 /// @param xnId Node id for node to apply R2 to.
285 /// Node will be automatically pushed to the solver stack.
///
/// x is the node being reduced; y and z are the nodes at the far ends of its
/// two solver edges. A yLen-by-zLen matrix, delta, is built from the edge
/// costs of (y,x) and (z,x) and the node costs of x. It is then merged into
/// the edge between y and z, which is created if it does not exist yet.
/// Finally both edges of x are removed from the solver graph.
286 void applyR2(Graph::NodeId xnId) {
287 assert(getSolverNodeData(xnId).getSolverDegree() == 2 &&
288 "R2 applied to node with degree != 2.");
290 NodeData &nd = getSolverNodeData(xnId);
291 const Vector &xCosts = g.getNodeCosts(xnId);
// Pick up x's solver edges, starting from the first entry of its edge list.
293 SolverEdgeItr aeItr = nd.solverEdgesBegin();
294 Graph::EdgeId yxeId = *aeItr,
297 Graph::NodeId ynId = g.getEdgeOtherNode(yxeId, xnId),
298 znId = g.getEdgeOtherNode(zxeId, xnId);
// An edge needs flipping when x is its node 1. After the transpose, rows are
// indexed by the neighbour's selection and columns by x's selection, which
// is the layout used by the loops below.
300 bool flipEdge1 = (g.getEdgeNode1(yxeId) == xnId),
301 flipEdge2 = (g.getEdgeNode1(zxeId) == xnId);
// NOTE(review): a flipped matrix is a heap-allocated transposed copy, while
// an unflipped one points at the graph's own matrix. Confirm the copies are
// released on every path.
303 const Matrix *yxeCosts = flipEdge1 ?
304 new Matrix(g.getEdgeCosts(yxeId).transpose()) :
305 &g.getEdgeCosts(yxeId);
307 const Matrix *zxeCosts = flipEdge2 ?
308 new Matrix(g.getEdgeCosts(zxeId).transpose()) :
309 &g.getEdgeCosts(zxeId);
311 unsigned xLen = xCosts.getLength(),
312 yLen = yxeCosts->getRows(),
313 zLen = zxeCosts->getRows();
// delta is indexed [y selection][z selection].
315 Matrix delta(yLen, zLen);
// For each pair (i, j) of y/z selections, look for the cheapest selection k
// of x, counting both edge costs and x's own node cost.
// NOTE(review): the minimum is presumably stored in delta[i][j] - confirm.
317 for (unsigned i = 0; i < yLen; ++i) {
318 for (unsigned j = 0; j < zLen; ++j) {
319 PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0];
320 for (unsigned k = 1; k < xLen; ++k) {
321 PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k];
// Look for an existing edge between y and z.
336 Graph::EdgeId yzeId = g.findEdge(ynId, znId);
337 bool addedEdge = false;
// No such edge: create one whose cost matrix is delta.
339 if (yzeId == g.invalidEdgeId()) {
340 yzeId = g.addEdge(ynId, znId, delta);
// Existing edge: its costs are updated in place, bracketed by the
// heuristic's pre/post update callbacks.
343 Matrix &yzeCosts = g.getEdgeCosts(yzeId);
344 h.preUpdateEdgeCosts(yzeId);
// delta has y on its rows. NOTE(review): presumably delta is added as-is
// when y is node 1 of the edge and transposed otherwise - confirm.
345 if (ynId == g.getEdgeNode1(yzeId)) {
348 yzeCosts += delta.transpose();
// tryNormaliseEdgeMatrix() returns true when the edge's matrix is all zero.
352 bool nullCostEdge = tryNormaliseEdgeMatrix(yzeId);
355 // If we modified the edge costs let the heuristic know.
356 h.postUpdateEdgeCosts(yzeId);
360 // If this edge ended up null remove it.
362 // We didn't just add it, so we need to notify the heuristic
363 // and remove it from the solver.
364 h.handleRemoveEdge(yzeId, ynId);
365 h.handleRemoveEdge(yzeId, znId);
366 removeSolverEdge(yzeId);
369 } else if (addedEdge) {
370 // If the edge was added, and non-null, finish setting it up, add it to
371 // the solver & notify heuristic.
372 edgeDataList.push_back(EdgeData());
373 g.setEdgeData(yzeId, &edgeDataList.back());
374 addSolverEdge(yzeId);
375 h.handleAddEdge(yzeId);
// Detach x: notify the heuristic for each neighbour, then drop both edges of
// x from the solver graph.
378 h.handleRemoveEdge(yxeId, ynId);
379 removeSolverEdge(yxeId);
380 h.handleRemoveEdge(zxeId, znId);
381 removeSolverEdge(zxeId);
387 /// \brief Record an application of the RN rule.
389 /// For use by the HeuristicBase.
///
/// Forwards to recordRN() on the solution object s.
390 void recordRN() { s.recordRN(); }
// Solver record for node nId. The graph hands back the pointer that was
// registered for the node with setNodeData(); it is cast back to NodeData.
394 NodeData& getSolverNodeData(Graph::NodeId nId) {
395 return *static_cast<NodeData*>(g.getNodeData(nId));
// Solver record for edge eId. The graph hands back the pointer that was
// registered for the edge with setEdgeData(); it is cast back to EdgeData.
398 EdgeData& getSolverEdgeData(Graph::EdgeId eId) {
399 return *static_cast<EdgeData*>(g.getEdgeData(eId));
// Attach edge eId to the solver graph: the edge id is appended to the edge
// list of each end node, and the resulting list positions are cached in the
// edge's EdgeData so that removeSolverEdge(eId) can erase them again.
402 void addSolverEdge(Graph::EdgeId eId) {
403 EdgeData &eData = getSolverEdgeData(eId);
404 NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eId)),
405 &n2Data = getSolverNodeData(g.getEdgeNode2(eId));
407 eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eId));
408 eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eId));
412 if (h.solverRunSimplify()) {
416 // Create node data objects.
417 for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
418 nItr != nEnd; ++nItr) {
419 nodeDataList.push_back(NodeData());
420 g.setNodeData(*nItr, &nodeDataList.back());
423 // Create edge data objects.
424 for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
425 eItr != eEnd; ++eItr) {
426 edgeDataList.push_back(EdgeData());
427 g.setEdgeData(*eItr, &edgeDataList.back());
428 addSolverEdge(*eItr);
433 disconnectTrivialNodes();
434 eliminateIndependentEdges();
437 // Eliminate trivial nodes.
// A node is treated as trivial when its cost vector has exactly one entry.
// For each edge of such a node, the edge's single row or column of costs is
// added to the cost vector of the node at the other end, and the edge is
// then removed from the graph.
438 void disconnectTrivialNodes() {
439 unsigned numDisconnected = 0;
441 for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
442 nItr != nEnd; ++nItr) {
444 Graph::NodeId nId = *nItr;
446 if (g.getNodeCosts(nId).getLength() == 1) {
// Edges are collected here and removed only after the adjacency loop below
// has finished.
448 std::vector<Graph::EdgeId> edgesToRemove;
450 for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nId),
451 aeEnd = g.adjEdgesEnd(nId);
452 aeItr != aeEnd; ++aeItr) {
454 Graph::EdgeId eId = *aeItr;
// The trivial node is node 1 of the edge: row 0 of the edge matrix is folded
// into the other node's costs.
456 if (g.getEdgeNode1(eId) == nId) {
457 Graph::NodeId otherNodeId = g.getEdgeNode2(eId);
458 g.getNodeCosts(otherNodeId) +=
459 g.getEdgeCosts(eId).getRowAsVector(0);
// The trivial node is node 2 of the edge: column 0 of the edge matrix is
// folded into the other node's costs.
462 Graph::NodeId otherNodeId = g.getEdgeNode1(eId);
463 g.getNodeCosts(otherNodeId) +=
464 g.getEdgeCosts(eId).getColAsVector(0);
467 edgesToRemove.push_back(eId);
470 if (!edgesToRemove.empty())
473 while (!edgesToRemove.empty()) {
474 g.removeEdge(edgesToRemove.back());
475 edgesToRemove.pop_back();
// Run tryToEliminateEdge() once on every edge that is in the graph when this
// function is entered.
481 void eliminateIndependentEdges() {
482 std::vector<Graph::EdgeId> edgesToProcess;
483 unsigned numEliminated = 0;
// Take a snapshot of the edge ids first; the elimination loop works from
// this copy rather than from the graph's edge iterators.
485 for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
486 eItr != eEnd; ++eItr) {
487 edgesToProcess.push_back(*eItr);
// Work through the snapshot from the back.
490 while (!edgesToProcess.empty()) {
491 if (tryToEliminateEdge(edgesToProcess.back()))
493 edgesToProcess.pop_back();
// Try to eliminate edge eId. The edge's cost matrix is normalised first;
// tryNormaliseEdgeMatrix() returns true when the matrix ended up all zero.
// NOTE(review): presumably an all-zero edge is removed from the graph and
// true is returned, false otherwise - confirm.
497 bool tryToEliminateEdge(Graph::EdgeId eId) {
498 if (tryNormaliseEdgeMatrix(eId)) {
// Normalise the cost matrix of edge eId in place and report whether the
// result is all zero.
//
// u is node 1 of the edge (matrix rows) and v is node 2 (matrix columns).
// Rows are handled first, then columns. For each row, the smallest entry is
// taken over the columns whose v cost is not infinite; when that minimum is
// finite it is subtracted from the whole row, otherwise the row is set to 0.
// Columns are treated the same way, using the u costs.
// NOTE(review): uCosts and vCosts are taken by non-const reference;
// presumably the subtracted minima are moved into them - confirm.
//
// @param eId Id of the edge whose matrix is normalised.
// @return edgeCosts.isZero() after normalisation.
505 bool tryNormaliseEdgeMatrix(Graph::EdgeId &eId) {
507 const PBQPNum infinity = std::numeric_limits<PBQPNum>::infinity();
509 Matrix &edgeCosts = g.getEdgeCosts(eId);
510 Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eId)),
511 &vCosts = g.getNodeCosts(g.getEdgeNode2(eId));
// Row pass.
513 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
514 PBQPNum rowMin = infinity;
// Columns whose v cost is infinite are ignored when looking for the minimum.
516 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
517 if (vCosts[c] != infinity && edgeCosts[r][c] < rowMin)
518 rowMin = edgeCosts[r][c];
523 if (rowMin != infinity) {
524 edgeCosts.subFromRow(r, rowMin);
527 edgeCosts.setRow(r, 0);
// Column pass.
531 for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
532 PBQPNum colMin = infinity;
// Rows whose u cost is infinite are ignored when looking for the minimum.
534 for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
535 if (uCosts[r] != infinity && edgeCosts[r][c] < colMin)
536 colMin = edgeCosts[r][c];
541 if (colMin != infinity) {
542 edgeCosts.subFromCol(c, colMin);
545 edgeCosts.setCol(c, 0);
549 return edgeCosts.isZero();
// Assign a selection to the nodes on the reduction stack, always taking the
// node at the back of the stack (the most recently pushed one).
552 void backpropagate() {
553 while (!stack.empty()) {
554 computeSolution(stack.back());
// Choose and record the selection for node nId.
//
// Starts from a copy of the node's own cost vector. For every edge in the
// node's solver edge list, the costs implied by the selection already
// recorded for the node at the other end are added. The index of the
// smallest entry of the resulting vector becomes the node's selection.
559 void computeSolution(Graph::NodeId nId) {
561 NodeData &nodeData = getSolverNodeData(nId);
// Working copy; the graph's node costs are not modified here.
563 Vector v(g.getNodeCosts(nId));
565 // Solve based on existing solved edges.
566 for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(),
567 solvedEdgeEnd = nodeData.solverEdgesEnd();
568 solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) {
570 Graph::EdgeId eId(*solvedEdgeItr);
571 Matrix &edgeCosts = g.getEdgeCosts(eId);
// nId is node 1 of the edge: the neighbour's selection picks a column.
573 if (nId == g.getEdgeNode1(eId)) {
574 Graph::NodeId adjNode(g.getEdgeNode2(eId));
575 unsigned adjSolution = s.getSelection(adjNode);
576 v += edgeCosts.getColAsVector(adjSolution);
// nId is node 2 of the edge: the neighbour's selection picks a row.
579 Graph::NodeId adjNode(g.getEdgeNode1(eId));
580 unsigned adjSolution = s.getSelection(adjNode);
581 v += edgeCosts.getRowAsVector(adjSolution);
// setSolution() also registers nId's edges with its neighbours.
586 setSolution(nId, v.minIndex());
591 nodeDataList.clear();
592 edgeDataList.clear();
596 /// \brief PBQP heuristic solver class.
598 /// Given a PBQP Graph g representing a PBQP problem, you can find a solution
/// by calling
600 /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt>
602 /// The choice of heuristic for the H parameter will affect both the solver
603 /// speed and solution quality. The heuristic should be chosen based on the
604 /// nature of the problem being solved.
605 /// Currently the only solver included with LLVM is the Briggs heuristic for
606 /// register allocation.
607 template <typename HImpl>
608 class HeuristicSolver {
/// \brief Solve the given PBQP problem.
/// @param g The graph representing the problem instance to be solved.
/// @return The result of HeuristicSolverImpl<HImpl>::computeSolution().
///
/// A HeuristicSolverImpl is created on the stack for the duration of the
/// call.
610 static Solution solve(Graph &g) {
611 HeuristicSolverImpl<HImpl> hs(g);
612 return hs.computeSolution();
618 #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H