code to lightly instrument at branches
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===--Graph.cpp--- implements Graph class ---------------- ------*- C++ -*--=//
2 //
3 // This implements Graph for helping in trace generation
4 // This graph gets used by "ProfilePaths" class
5 //
6 //===----------------------------------------------------------------------===//
7
8 #include "llvm/Transforms/Instrumentation/Graph.h"
9 #include "llvm/iTerminators.h"
10 #include "llvm/BasicBlock.h"
11 #include "Support/Statistic.h"
12 #include <algorithm>
13
14 //using std::list;
15 //using std::set;
16 using std::map;
17 using std::vector;
18 using std::cerr;
19
20 const graphListElement *findNodeInList(const Graph::nodeList &NL,
21                                               Node *N) {
22   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
23       ++NI)
24     if (*NI->element== *N)
25       return &*NI;
26   return 0;
27 }
28
29 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
30   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
31     if (*NI->element== *N)
32       return &*NI;
33   return 0;
34 }
35
36 //graph constructor with root and exit specified
37 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
38              Node *rt, Node *lt){
39   strt=rt;
40   ext=lt;
41   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
42     //nodes[*x] = list<graphListElement>();
43     nodes[*x] = vector<graphListElement>();
44
45   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
46     Edge ee=*x;
47     int w=ee.getWeight();
48     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
49     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
50   }
51   
52 }
53
54 //sorting edgelist, called by backEdgeVist ONLY!!!
55 Graph::nodeList &Graph::sortNodeList(Node *par, nodeList &nl, vector<Edge> &be){
56   assert(par && "null node pointer");
57   BasicBlock *bbPar = par->getElement();
58   
59   if(nl.size()<=1) return nl;
60   if(getExit() == par) return nl;
61
62   for(nodeList::iterator NLI = nl.begin(), NLE = nl.end()-1; NLI != NLE; ++NLI){
63     nodeList::iterator min = NLI;
64     for(nodeList::iterator LI = NLI+1, LE = nl.end(); LI!=LE; ++LI){
65       //if LI < min, min = LI
66       if(min->element->getElement() == LI->element->getElement() &&
67          min->element == getExit()){
68
69         //same successors: so might be exit???
70         //if it is exit, then see which is backedge
71         //check if LI is a left back edge!
72
73         TerminatorInst *tti = par->getElement()->getTerminator();
74         BranchInst *ti =  cast<BranchInst>(tti);
75
76         assert(ti && "not a branch");
77         assert(ti->getNumSuccessors()==2 && "less successors!");
78         
79         BasicBlock *tB = ti->getSuccessor(0);
80         BasicBlock *fB = ti->getSuccessor(1);
81         //so one of LI or min must be back edge!
82         //Algo: if succ(0)!=LI (and so !=min) then succ(0) is backedge
83         //and then see which of min or LI is backedge
84         //THEN if LI is in be, then min=LI
85         if(LI->element->getElement() != tB){//so backedge must be made min!
86           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
87               VBEI != VBEE; ++VBEI){
88             if(VBEI->getRandId() == LI->randId){
89               min = LI;
90               break;
91             }
92             else if(VBEI->getRandId() == min->randId)
93               break;
94           }
95         }
96         else{// if(LI->element->getElement() != fB)
97           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
98               VBEI != VBEE; ++VBEI){
99             if(VBEI->getRandId() == min->randId){
100               min = LI;
101               break;
102             }
103             else if(VBEI->getRandId() == LI->randId)
104               break;
105           }
106         }
107       }
108       
109       else if (min->element->getElement() != LI->element->getElement()){
110         TerminatorInst *tti = par->getElement()->getTerminator();
111         BranchInst *ti =  cast<BranchInst>(tti);
112         assert(ti && "not a branch");
113
114         if(ti->getNumSuccessors()<=1) continue;
115         
116         assert(ti->getNumSuccessors()==2 && "less successors!");
117         
118         BasicBlock *tB = ti->getSuccessor(0);
119         BasicBlock *fB = ti->getSuccessor(1);
120         
121         if(tB == LI->element->getElement() || fB == min->element->getElement())
122           min = LI;
123       }
124     }
125     
126     graphListElement tmpElmnt = *min;
127     *min = *NLI;
128     *NLI = tmpElmnt;
129   }
130   return nl;
131 }
132
133 //check whether graph has an edge
134 //having an edge simply means that there is an edge in the graph
135 //which has same endpoints as the given edge
136 bool Graph::hasEdge(Edge ed){
137   if(ed.isNull())
138     return false;
139
140   nodeList &nli= nodes[ed.getFirst()]; //getNodeList(ed.getFirst());
141   Node *nd2=ed.getSecond();
142
143   return (findNodeInList(nli,nd2)!=NULL);
144
145 }
146
147
148 //check whether graph has an edge, with a given wt
149 //having an edge simply means that there is an edge in the graph
150 //which has same endpoints as the given edge
151 //This function checks, moreover, that the wt of edge matches too
152 bool Graph::hasEdgeAndWt(Edge ed){
153   if(ed.isNull())
154     return false;
155
156   Node *nd2=ed.getSecond();
157   nodeList nli = nodes[ed.getFirst()];//getNodeList(ed.getFirst());
158   
159   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
160     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
161       return true;
162   
163   return false;
164 }
165
166 //add a node
167 void Graph::addNode(Node *nd){
168   vector<Node *> lt=getAllNodes();
169
170   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
171     if(**LI==*nd)
172       return;
173   }
174   //chng
175   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
176 }
177
178 //add an edge
179 //this adds an edge ONLY when 
180 //the edge to be added doesn not already exist
181 //we "equate" two edges here only with their 
182 //end points
183 void Graph::addEdge(Edge ed, int w){
184   nodeList &ndList = nodes[ed.getFirst()];
185   Node *nd2=ed.getSecond();
186
187   if(findNodeInList(nodes[ed.getFirst()], nd2))
188     return;
189  
190   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
191   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
192   //sortNodeList(ed.getFirst(), ndList);
193
194   //sort(ndList.begin(), ndList.end(), NodeListSort());
195 }
196
197 //add an edge EVEN IF such an edge already exists
198 //this may make a multi-graph
199 //which does happen when we add dummy edges
200 //to the graph, for compensating for back-edges
201 void Graph::addEdgeForce(Edge ed){
202   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
203   //ed.getWeight(), ed.getRandId()));
204   nodes[ed.getFirst()].push_back
205     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
206
207   //sortNodeList(ed.getFirst(), nodes[ed.getFirst()]);
208   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
209 }
210
211 //remove an edge
212 //Note that it removes just one edge,
213 //the first edge that is encountered
214 void Graph::removeEdge(Edge ed){
215   nodeList &ndList = nodes[ed.getFirst()];
216   Node &nd2 = *ed.getSecond();
217
218   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
219     if(*NI->element == nd2) {
220       ndList.erase(NI);
221       break;
222     }
223   }
224 }
225
226 //remove an edge with a given wt
227 //Note that it removes just one edge,
228 //the first edge that is encountered
229 void Graph::removeEdgeWithWt(Edge ed){
230   nodeList &ndList = nodes[ed.getFirst()];
231   Node &nd2 = *ed.getSecond();
232
233   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
234     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
235       ndList.erase(NI);
236       break;
237     }
238   }
239 }
240
241 //set the weight of an edge
242 void Graph::setWeight(Edge ed){
243   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
244   if (El)
245     El->weight=ed.getWeight();
246 }
247
248
249
250 //get the list of successor nodes
251 vector<Node *> Graph::getSuccNodes(Node *nd){
252   nodeMapTy::const_iterator nli = nodes.find(nd);
253   assert(nli != nodes.end() && "Node must be in nodes map");
254   const nodeList &nl = getNodeList(nd);//getSortedNodeList(nd);
255
256   vector<Node *> lt;
257   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
258     lt.push_back(NI->element);
259
260   return lt;
261 }
262
263 //get the number of outgoing edges
264 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
265   nodeMapTy::const_iterator nli = nodes.find(nd);
266   assert(nli != nodes.end() && "Node must be in nodes map");
267   const nodeList &nl = nli->second;
268
269   int count=0;
270   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
271     count++;
272
273   return count;
274 }
275
276 //get the list of predecessor nodes
277 vector<Node *> Graph::getPredNodes(Node *nd){
278   vector<Node *> lt;
279   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
280     Node *lnode=EI->first;
281     const nodeList &nl = getNodeList(lnode);
282
283     const graphListElement *N = findNodeInList(nl, nd);
284     if (N) lt.push_back(lnode);
285   }
286   return lt;
287 }
288
289 //get the number of predecessor nodes
290 int Graph::getNumberOfIncomingEdges(Node *nd){
291   int count=0;
292   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
293     Node *lnode=EI->first;
294     const nodeList &nl = getNodeList(lnode);
295     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
296         ++NI)
297       if (*NI->element== *nd)
298         count++;
299   }
300   return count;
301 }
302
303 //get the list of all the vertices in graph
304 vector<Node *> Graph::getAllNodes() const{
305   vector<Node *> lt;
306   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
307     lt.push_back(x->first);
308
309   return lt;
310 }
311
312 //get the list of all the vertices in graph
313 vector<Node *> Graph::getAllNodes(){
314   vector<Node *> lt;
315   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
316     lt.push_back(x->first);
317
318   return lt;
319 }
320
321 //class to compare two nodes in graph
322 //based on their wt: this is used in
323 //finding the maximal spanning tree
324 struct compare_nodes {
325   bool operator()(Node *n1, Node *n2){
326     return n1->getWeight() < n2->getWeight();
327   }
328 };
329
330
331 static void printNode(Node *nd){
332   cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
333 }
334
335 //Get the Maximal spanning tree (also a graph)
336 //of the graph
337 Graph* Graph::getMaxSpanningTree(){
338   //assume connected graph
339  
340   Graph *st=new Graph();//max spanning tree, undirected edges
341   int inf=9999999;//largest key
342   vector<Node *> lt = getAllNodes();
343   
344   //initially put all vertices in vector vt
345   //assign wt(root)=0
346   //wt(others)=infinity
347   //
348   //now:
349   //pull out u: a vertex frm vt of min wt
350   //for all vertices w in vt, 
351   //if wt(w) greater than 
352   //the wt(u->w), then assign
353   //wt(w) to be wt(u->w).
354   //
355   //make parent(u)=w in the spanning tree
356   //keep pulling out vertices from vt till it is empty
357
358   vector<Node *> vt;
359   
360   map<Node*, Node* > parent;
361   map<Node*, int > ed_weight;
362
363   //initialize: wt(root)=0, wt(others)=infinity
364   //parent(root)=NULL, parent(others) not defined (but not null)
365   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
366     Node *thisNode=*LI;
367     if(*thisNode == *getRoot()){
368       thisNode->setWeight(0);
369       parent[thisNode]=NULL;
370       ed_weight[thisNode]=0;
371     }
372     else{ 
373       thisNode->setWeight(inf);
374     }
375     st->addNode(thisNode);//add all nodes to spanning tree
376     //we later need to assign edges in the tree
377     vt.push_back(thisNode); //pushed all nodes in vt
378   }
379
380   //keep pulling out vertex of min wt from vt
381   while(!vt.empty()){
382     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
383     DEBUG(cerr<<"popped wt"<<(u)->getWeight()<<"\n";
384           printNode(u));
385
386     if(parent[u]!=NULL){ //so not root
387       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
388       st->addEdge(edge,ed_weight[u]);
389
390       DEBUG(cerr<<"added:\n";
391             printEdge(edge));
392     }
393
394     //vt.erase(u);
395     
396     //remove u frm vt
397     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
398       if(**VI==*u){
399         vt.erase(VI);
400         break;
401       }
402     }
403     
404     //assign wt(v) to all adjacent vertices v of u
405     //only if v is in vt
406     Graph::nodeList nl=getNodeList(u);
407     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
408       Node *v=NI->element;
409       int weight=-NI->weight;
410       //check if v is in vt
411       bool contains=false;
412       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
413         if(**VI==*v){
414           contains=true;
415           break;
416         }
417       }
418       DEBUG(cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
419             printNode(v);cerr<<"node wt:"<<(*v).weight<<"\n");
420
421       //so if v in in vt, change wt(v) to wt(u->v)
422       //only if wt(u->v)<wt(v)
423       if(contains && weight<v->getWeight()){
424         parent[v]=u;
425         ed_weight[v]=weight;
426         v->setWeight(weight);
427
428         DEBUG(cerr<<v->getWeight()<<":Set weight------\n";
429               printGraph();
430               printEdge(Edge(u,v,weight)));
431       }
432     }
433   }
434   return st;
435 }
436
437 //print the graph (for debugging)   
438 void Graph::printGraph(){
439    vector<Node *> lt=getAllNodes();
440    cerr<<"Graph---------------------\n";
441    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
442      cerr<<((*LI)->getElement())->getName()<<"->";
443      Graph::nodeList nl=getNodeList(*LI);
444      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
445        cerr<<":"<<"("<<(NI->element->getElement())
446          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
447      }
448      cerr<<"--------\n";
449    }
450 }
451
452
453 //get a list of nodes in the graph
454 //in r-topological sorted order
455 //note that we assumed graph to be connected
456 vector<Node *> Graph::reverseTopologicalSort(){
457   vector <Node *> toReturn;
458   vector<Node *> lt=getAllNodes();
459   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
460     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
461       DFS_Visit(*LI, toReturn);
462   }
463
464   return toReturn;
465 }
466
467 //a private method for doing DFS traversal of graph
468 //this is used in determining the reverse topological sort 
469 //of the graph
470 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn){
471   nd->setWeight(GREY);
472   vector<Node *> lt=getSuccNodes(nd);
473   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
474     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
475       DFS_Visit(*LI, toReturn);
476   }
477   toReturn.push_back(nd);
478 }
479
480 //Ordinarily, the graph is directional
481 //this converts the graph into an 
482 //undirectional graph
483 //This is done by adding an edge
484 //v->u for all existing edges u->v
485 void Graph::makeUnDirectional(){
486   vector<Node* > allNodes=getAllNodes();
487   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
488       ++NI) {
489     nodeList nl=getNodeList(*NI);
490     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
491       Edge ed(NLI->element, *NI, NLI->weight);
492       if(!hasEdgeAndWt(ed)){
493         DEBUG(cerr<<"######doesn't hv\n";
494               printEdge(ed));
495         addEdgeForce(ed);
496       }
497     }
498   }
499 }
500
501 //reverse the sign of weights on edges
502 //this way, max-spanning tree could be obtained
503 //usin min-spanning tree, and vice versa
504 void Graph::reverseWts(){
505   vector<Node *> allNodes=getAllNodes();
506   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
507       ++NI) {
508     nodeList node_list=getNodeList(*NI);
509     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
510         NLI!=NLE; ++NLI)
511       NLI->weight=-NLI->weight;
512   }
513 }
514
515
516 //getting the backedges in a graph
517 //Its a variation of DFS to get the backedges in the graph
518 //We get back edges by associating a time
519 //and a color with each vertex.
520 //The time of a vertex is the time when it was first visited
521 //The color of a vertex is initially WHITE,
522 //Changes to GREY when it is first visited,
523 //and changes to BLACK when ALL its neighbors
524 //have been visited
525 //So we have a back edge when we meet a successor of
526 //a node with smaller time, and GREY color
527 void Graph::getBackEdges(vector<Edge > &be, map<Node *, int> &d){
528   map<Node *, Color > color;
529   int time=0;
530
531   getBackEdgesVisit(getRoot(), be, color, d, time);
532 }
533
534 //helper function to get back edges: it is called by 
535 //the "getBackEdges" function above
536 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
537                               map<Node *, Color > &color,
538                               map<Node *, int > &d, int &time) {
539   color[u]=GREY;
540   time++;
541   d[u]=time;
542
543   vector<graphListElement> &succ_list = getNodeList(u);
544   
545   for(vector<graphListElement>::iterator vl=succ_list.begin(), 
546         ve=succ_list.end(); vl!=ve; ++vl){
547     Node *v=vl->element;
548     if(color[v]!=GREY && color[v]!=BLACK){
549       getBackEdgesVisit(v, be, color, d, time);
550     }
551     
552     //now checking for d and f vals
553     if(color[v]==GREY){
554       //so v is ancestor of u if time of u > time of v
555       if(d[u] >= d[v]){
556         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
557         if (!(*u == *getExit() && *v == *getRoot()))
558           be.push_back(*ed);      // choose the forward edges
559       }
560     }
561   }
562   color[u]=BLACK;//done with visiting the node and its neighbors
563 }
564
565