* Add support for different "PassType's"
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===--Graph.cpp--- implements Graph class ---------------- ------*- C++ -*--=//
2 //
3 // This implements Graph for helping in trace generation
4 // This graph gets used by "ProfilePaths" class
5 //
6 //===----------------------------------------------------------------------===//
7
8 #include "llvm/Transforms/Instrumentation/Graph.h"
9 #include "llvm/iTerminators.h"
10 #include "llvm/BasicBlock.h"
11 #include <algorithm>
12 #include <iostream>
13
14 //using std::list;
15 //using std::set;
16 using std::map;
17 using std::vector;
18 using std::cerr;
19
20 const graphListElement *findNodeInList(const Graph::nodeList &NL,
21                                               Node *N) {
22   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
23       ++NI)
24     if (*NI->element== *N)
25       return &*NI;
26   return 0;
27 }
28
29 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
30   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
31     if (*NI->element== *N)
32       return &*NI;
33   return 0;
34 }
35
36 //graph constructor with root and exit specified
37 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
38              Node *rt, Node *lt){
39   strt=rt;
40   ext=lt;
41   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
42     //nodes[*x] = list<graphListElement>();
43     nodes[*x] = vector<graphListElement>();
44
45   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
46     Edge ee=*x;
47     int w=ee.getWeight();
48     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
49     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
50   }
51   
52 }
53
54 //sorting edgelist, called by backEdgeVist ONLY!!!
55 Graph::nodeList &Graph::sortNodeList(Node *par, nodeList &nl){
56   assert(par && "null node pointer");
57   BasicBlock *bbPar = par->getElement();
58   
59   if(nl.size()<=1) return nl;
60
61   for(nodeList::iterator NLI = nl.begin(), NLE = nl.end()-1; NLI != NLE; ++NLI){
62     nodeList::iterator min = NLI;
63     for(nodeList::iterator LI = NLI+1, LE = nl.end(); LI!=LE; ++LI){
64       //if LI < min, min = LI
65       if(min->element->getElement() == LI->element->getElement())
66         continue;
67       
68
69       TerminatorInst *tti = par->getElement()->getTerminator();
70       BranchInst *ti =  cast<BranchInst>(tti);
71       assert(ti && "not a branch");
72       assert(ti->getNumSuccessors()==2 && "less successors!");
73       
74       BasicBlock *tB = ti->getSuccessor(0);
75       BasicBlock *fB = ti->getSuccessor(1);
76       
77       if(tB == LI->element->getElement() || fB == min->element->getElement())
78         min = LI;
79     }
80     
81     graphListElement tmpElmnt = *min;
82     *min = *NLI;
83     *NLI = tmpElmnt;
84   }
85   return nl;
86 }
87
88 //check whether graph has an edge
89 //having an edge simply means that there is an edge in the graph
90 //which has same endpoints as the given edge
91 bool Graph::hasEdge(Edge ed){
92   if(ed.isNull())
93     return false;
94
95   nodeList &nli= nodes[ed.getFirst()]; //getNodeList(ed.getFirst());
96   Node *nd2=ed.getSecond();
97
98   return (findNodeInList(nli,nd2)!=NULL);
99
100 }
101
102
103 //check whether graph has an edge, with a given wt
104 //having an edge simply means that there is an edge in the graph
105 //which has same endpoints as the given edge
106 //This function checks, moreover, that the wt of edge matches too
107 bool Graph::hasEdgeAndWt(Edge ed){
108   if(ed.isNull())
109     return false;
110
111   Node *nd2=ed.getSecond();
112   nodeList nli = nodes[ed.getFirst()];//getNodeList(ed.getFirst());
113   
114   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
115     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
116       return true;
117   
118   return false;
119 }
120
121 //add a node
122 void Graph::addNode(Node *nd){
123   vector<Node *> lt=getAllNodes();
124
125   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
126     if(**LI==*nd)
127       return;
128   }
129   //chng
130   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
131 }
132
133 //add an edge
134 //this adds an edge ONLY when 
135 //the edge to be added doesn not already exist
136 //we "equate" two edges here only with their 
137 //end points
138 void Graph::addEdge(Edge ed, int w){
139   nodeList &ndList = nodes[ed.getFirst()];
140   Node *nd2=ed.getSecond();
141
142   if(findNodeInList(nodes[ed.getFirst()], nd2))
143     return;
144  
145   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
146   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
147   //sortNodeList(ed.getFirst(), ndList);
148
149   //sort(ndList.begin(), ndList.end(), NodeListSort());
150 }
151
152 //add an edge EVEN IF such an edge already exists
153 //this may make a multi-graph
154 //which does happen when we add dummy edges
155 //to the graph, for compensating for back-edges
156 void Graph::addEdgeForce(Edge ed){
157   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
158   //ed.getWeight(), ed.getRandId()));
159   nodes[ed.getFirst()].push_back
160     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
161
162   //sortNodeList(ed.getFirst(), nodes[ed.getFirst()]);
163   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
164 }
165
166 //remove an edge
167 //Note that it removes just one edge,
168 //the first edge that is encountered
169 void Graph::removeEdge(Edge ed){
170   nodeList &ndList = nodes[ed.getFirst()];
171   Node &nd2 = *ed.getSecond();
172
173   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
174     if(*NI->element == nd2) {
175       ndList.erase(NI);
176       break;
177     }
178   }
179 }
180
181 //remove an edge with a given wt
182 //Note that it removes just one edge,
183 //the first edge that is encountered
184 void Graph::removeEdgeWithWt(Edge ed){
185   nodeList &ndList = nodes[ed.getFirst()];
186   Node &nd2 = *ed.getSecond();
187
188   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
189     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
190       ndList.erase(NI);
191       break;
192     }
193   }
194 }
195
196 //set the weight of an edge
197 void Graph::setWeight(Edge ed){
198   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
199   if (El)
200     El->weight=ed.getWeight();
201 }
202
203
204
205 //get the list of successor nodes
206 vector<Node *> Graph::getSuccNodes(Node *nd){
207   nodeMapTy::const_iterator nli = nodes.find(nd);
208   assert(nli != nodes.end() && "Node must be in nodes map");
209   const nodeList &nl = getNodeList(nd);//getSortedNodeList(nd);
210
211   vector<Node *> lt;
212   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
213     lt.push_back(NI->element);
214
215   return lt;
216 }
217
218 //get the number of outgoing edges
219 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
220   nodeMapTy::const_iterator nli = nodes.find(nd);
221   assert(nli != nodes.end() && "Node must be in nodes map");
222   const nodeList &nl = nli->second;
223
224   int count=0;
225   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
226     count++;
227
228   return count;
229 }
230
231 //get the list of predecessor nodes
232 vector<Node *> Graph::getPredNodes(Node *nd){
233   vector<Node *> lt;
234   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
235     Node *lnode=EI->first;
236     const nodeList &nl = getNodeList(lnode);
237
238     const graphListElement *N = findNodeInList(nl, nd);
239     if (N) lt.push_back(lnode);
240   }
241   return lt;
242 }
243
244 //get the number of predecessor nodes
245 int Graph::getNumberOfIncomingEdges(Node *nd){
246   int count=0;
247   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
248     Node *lnode=EI->first;
249     const nodeList &nl = getNodeList(lnode);
250     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
251         ++NI)
252       if (*NI->element== *nd)
253         count++;
254   }
255   return count;
256 }
257
258 //get the list of all the vertices in graph
259 vector<Node *> Graph::getAllNodes() const{
260   vector<Node *> lt;
261   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
262     lt.push_back(x->first);
263
264   return lt;
265 }
266
267 //get the list of all the vertices in graph
268 vector<Node *> Graph::getAllNodes(){
269   vector<Node *> lt;
270   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
271     lt.push_back(x->first);
272
273   return lt;
274 }
275
276 //class to compare two nodes in graph
277 //based on their wt: this is used in
278 //finding the maximal spanning tree
279 struct compare_nodes {
280   bool operator()(Node *n1, Node *n2){
281     return n1->getWeight() < n2->getWeight();
282   }
283 };
284
285
286 static void printNode(Node *nd){
287   cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
288 }
289
290 //Get the Maximal spanning tree (also a graph)
291 //of the graph
292 Graph* Graph::getMaxSpanningTree(){
293   //assume connected graph
294  
295   Graph *st=new Graph();//max spanning tree, undirected edges
296   int inf=9999999;//largest key
297   vector<Node *> lt = getAllNodes();
298   
299   //initially put all vertices in vector vt
300   //assign wt(root)=0
301   //wt(others)=infinity
302   //
303   //now:
304   //pull out u: a vertex frm vt of min wt
305   //for all vertices w in vt, 
306   //if wt(w) greater than 
307   //the wt(u->w), then assign
308   //wt(w) to be wt(u->w).
309   //
310   //make parent(u)=w in the spanning tree
311   //keep pulling out vertices from vt till it is empty
312
313   vector<Node *> vt;
314   
315   map<Node*, Node* > parent;
316   map<Node*, int > ed_weight;
317
318   //initialize: wt(root)=0, wt(others)=infinity
319   //parent(root)=NULL, parent(others) not defined (but not null)
320   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
321     Node *thisNode=*LI;
322     if(*thisNode == *getRoot()){
323       thisNode->setWeight(0);
324       parent[thisNode]=NULL;
325       ed_weight[thisNode]=0;
326     }
327     else{ 
328       thisNode->setWeight(inf);
329     }
330     st->addNode(thisNode);//add all nodes to spanning tree
331     //we later need to assign edges in the tree
332     vt.push_back(thisNode); //pushed all nodes in vt
333   }
334
335   //keep pulling out vertex of min wt from vt
336   while(!vt.empty()){
337     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
338     DEBUG(cerr<<"popped wt"<<(u)->getWeight()<<"\n";
339           printNode(u));
340
341     if(parent[u]!=NULL){ //so not root
342       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
343       st->addEdge(edge,ed_weight[u]);
344
345       DEBUG(cerr<<"added:\n";
346             printEdge(edge));
347     }
348
349     //vt.erase(u);
350     
351     //remove u frm vt
352     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
353       if(**VI==*u){
354         vt.erase(VI);
355         break;
356       }
357     }
358     
359     //assign wt(v) to all adjacent vertices v of u
360     //only if v is in vt
361     Graph::nodeList nl=getNodeList(u);
362     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
363       Node *v=NI->element;
364       int weight=-NI->weight;
365       //check if v is in vt
366       bool contains=false;
367       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
368         if(**VI==*v){
369           contains=true;
370           break;
371         }
372       }
373       DEBUG(cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
374             printNode(v);cerr<<"node wt:"<<(*v).weight<<"\n");
375
376       //so if v in in vt, change wt(v) to wt(u->v)
377       //only if wt(u->v)<wt(v)
378       if(contains && weight<v->getWeight()){
379         parent[v]=u;
380         ed_weight[v]=weight;
381         v->setWeight(weight);
382
383         DEBUG(cerr<<v->getWeight()<<":Set weight------\n";
384               printGraph();
385               printEdge(Edge(u,v,weight)));
386       }
387     }
388   }
389   return st;
390 }
391
392 //print the graph (for debugging)   
393 void Graph::printGraph(){
394    vector<Node *> lt=getAllNodes();
395    cerr<<"Graph---------------------\n";
396    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
397      cerr<<((*LI)->getElement())->getName()<<"->";
398      Graph::nodeList nl=getNodeList(*LI);
399      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
400        cerr<<":"<<"("<<(NI->element->getElement())
401          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
402      }
403      cerr<<"--------\n";
404    }
405 }
406
407
408 //get a list of nodes in the graph
409 //in r-topological sorted order
410 //note that we assumed graph to be connected
411 vector<Node *> Graph::reverseTopologicalSort(){
412   vector <Node *> toReturn;
413   vector<Node *> lt=getAllNodes();
414   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
415     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
416       DFS_Visit(*LI, toReturn);
417   }
418
419   //print nodes
420   //std::cerr<<"Topological sort--------\n";
421   //for(vector<Node *>::iterator VI = toReturn.begin(), VE = toReturn.end(); 
422   //  VI!=VE; ++VI)
423   //std::cerr<<(*VI)->getElement()->getName()<<"->";
424   //std::cerr<<"\n----------------------\n";
425   return toReturn;
426 }
427
428 //a private method for doing DFS traversal of graph
429 //this is used in determining the reverse topological sort 
430 //of the graph
431 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn){
432   nd->setWeight(GREY);
433   vector<Node *> lt=getSuccNodes(nd);
434   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
435     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
436       DFS_Visit(*LI, toReturn);
437   }
438   toReturn.push_back(nd);
439 }
440
441 //Ordinarily, the graph is directional
442 //this converts the graph into an 
443 //undirectional graph
444 //This is done by adding an edge
445 //v->u for all existing edges u->v
446 void Graph::makeUnDirectional(){
447   vector<Node* > allNodes=getAllNodes();
448   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
449       ++NI) {
450     nodeList nl=getNodeList(*NI);
451     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
452       Edge ed(NLI->element, *NI, NLI->weight);
453       if(!hasEdgeAndWt(ed)){
454         DEBUG(cerr<<"######doesn't hv\n";
455               printEdge(ed));
456         addEdgeForce(ed);
457       }
458     }
459   }
460 }
461
462 //reverse the sign of weights on edges
463 //this way, max-spanning tree could be obtained
464 //usin min-spanning tree, and vice versa
465 void Graph::reverseWts(){
466   vector<Node *> allNodes=getAllNodes();
467   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
468       ++NI) {
469     nodeList node_list=getNodeList(*NI);
470     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
471         NLI!=NLE; ++NLI)
472       NLI->weight=-NLI->weight;
473   }
474 }
475
476
477 //getting the backedges in a graph
478 //Its a variation of DFS to get the backedges in the graph
479 //We get back edges by associating a time
480 //and a color with each vertex.
481 //The time of a vertex is the time when it was first visited
482 //The color of a vertex is initially WHITE,
483 //Changes to GREY when it is first visited,
484 //and changes to BLACK when ALL its neighbors
485 //have been visited
486 //So we have a back edge when we meet a successor of
487 //a node with smaller time, and GREY color
488 void Graph::getBackEdges(vector<Edge > &be, map<Node *, int> &d){
489   map<Node *, Color > color;
490   //map<Node *, int > d;
491   //vector<Node *> allNodes=getAllNodes();
492   int time=0;
493   //for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); 
494   //  NI!=NE; ++NI){
495   //if(color[*NI]!=GREY && color[*NI]!=BLACK)
496   //printGraph();
497   getBackEdgesVisit(getRoot(), be, color, d, time);//*NI, be, color, d, time);
498   //}
499 }
500
501 //helper function to get back edges: it is called by 
502 //the "getBackEdges" function above
503 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
504                               map<Node *, Color > &color,
505                               map<Node *, int > &d, int &time) {
506   color[u]=GREY;
507   time++;
508   d[u]=time;
509
510   //std::cerr<<"Node list-----\n";
511   vector<graphListElement> succ_list = getSortedNodeList(u);
512   
513   //for(vector<graphListElement>::iterator vl=succ_list.begin(), 
514   //    ve=succ_list.end(); vl!=ve; ++vl){
515   //Node *v=vl->element;
516   //std::cerr<<v->getElement()->getName()<<"->";
517   //}
518   //std::cerr<<"\n-------- end Node list\n";
519   
520   for(vector<graphListElement>::iterator vl=succ_list.begin(), 
521         ve=succ_list.end(); vl!=ve; ++vl){
522     Node *v=vl->element;
523     //  for(vector<Node *>::const_iterator v=succ_list.begin(), ve=succ_list.end(); 
524       //  v!=ve; ++v){
525       
526       if(color[v]!=GREY && color[v]!=BLACK){
527         getBackEdgesVisit(v, be, color, d, time);
528       }
529     
530     //now checking for d and f vals
531     if(color[v]==GREY){
532       //so v is ancestor of u if time of u > time of v
533       if(d[u] >= d[v]){
534         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
535         if (!(*u == *getExit() && *v == *getRoot()))
536           be.push_back(*ed);      // choose the forward edges
537       }
538     }
539   }
540   color[u]=BLACK;//done with visiting the node and its neighbors
541 }
542
543