additions and bug fixes
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===--Graph.cpp--- implements Graph class ---------------- ------*- C++ -*--=//
2 //
3 // This implements Graph for helping in trace generation
4 // This graph gets used by "ProfilePaths" class
5 //
6 //===----------------------------------------------------------------------===//
7
8 #include "llvm/Transforms/Instrumentation/Graph.h"
9 #include "llvm/BasicBlock.h"
10 #include <algorithm>
11 #include <iostream>
12
13 //using std::list;
14 //using std::set;
15 using std::map;
16 using std::vector;
17 using std::cerr;
18
19 const graphListElement *findNodeInList(const Graph::nodeList &NL,
20                                               Node *N) {
21   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
22       ++NI)
23     if (*NI->element== *N)
24       return &*NI;
25   return 0;
26 }
27
28 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
29   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
30     if (*NI->element== *N)
31       return &*NI;
32   return 0;
33 }
34
35 //graph constructor with root and exit specified
36 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
37              Node *rt, Node *lt){
38   strt=rt;
39   ext=lt;
40   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
41     //nodes[*x] = list<graphListElement>();
42     nodes[*x] = vector<graphListElement>();
43
44   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
45     Edge ee=*x;
46     int w=ee.getWeight();
47     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
48     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
49   }
50   
51 }
52
53 //check whether graph has an edge
54 //having an edge simply means that there is an edge in the graph
55 //which has same endpoints as the given edge
56 bool Graph::hasEdge(Edge ed) const{
57   if(ed.isNull())
58     return false;
59
60   nodeList nli=getNodeList(ed.getFirst());
61   Node *nd2=ed.getSecond();
62
63   return (findNodeInList(nli,nd2)!=NULL);
64
65 }
66
67
68 //check whether graph has an edge, with a given wt
69 //having an edge simply means that there is an edge in the graph
70 //which has same endpoints as the given edge
71 //This function checks, moreover, that the wt of edge matches too
72 bool Graph::hasEdgeAndWt(Edge ed) const{
73   if(ed.isNull())
74     return false;
75
76   Node *nd2=ed.getSecond();
77   nodeList nli=getNodeList(ed.getFirst());
78   
79   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
80     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
81       return true;
82   
83   return false;
84 }
85
86 //add a node
87 void Graph::addNode(Node *nd){
88   vector<Node *> lt=getAllNodes();
89
90   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
91     if(**LI==*nd)
92       return;
93   }
94   //chng
95   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
96 }
97
98 //add an edge
99 //this adds an edge ONLY when 
100 //the edge to be added doesn not already exist
101 //we "equate" two edges here only with their 
102 //end points
103 void Graph::addEdge(Edge ed, int w){
104   nodeList &ndList = nodes[ed.getFirst()];
105   Node *nd2=ed.getSecond();
106
107   if(findNodeInList(nodes[ed.getFirst()], nd2))
108     return;
109  
110   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
111   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
112
113   //sort(ndList.begin(), ndList.end(), NodeListSort());
114 }
115
116 //add an edge EVEN IF such an edge already exists
117 //this may make a multi-graph
118 //which does happen when we add dummy edges
119 //to the graph, for compensating for back-edges
120 void Graph::addEdgeForce(Edge ed){
121   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
122   //ed.getWeight(), ed.getRandId()));
123   nodes[ed.getFirst()].push_back
124     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
125
126   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
127 }
128
129 //remove an edge
130 //Note that it removes just one edge,
131 //the first edge that is encountered
132 void Graph::removeEdge(Edge ed){
133   nodeList &ndList = nodes[ed.getFirst()];
134   Node &nd2 = *ed.getSecond();
135
136   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
137     if(*NI->element == nd2) {
138       ndList.erase(NI);
139       break;
140     }
141   }
142 }
143
144 //remove an edge with a given wt
145 //Note that it removes just one edge,
146 //the first edge that is encountered
147 void Graph::removeEdgeWithWt(Edge ed){
148   nodeList &ndList = nodes[ed.getFirst()];
149   Node &nd2 = *ed.getSecond();
150
151   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
152     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
153       ndList.erase(NI);
154       break;
155     }
156   }
157 }
158
159 //set the weight of an edge
160 void Graph::setWeight(Edge ed){
161   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
162   if (El)
163     El->weight=ed.getWeight();
164 }
165
166
167
168 //get the list of successor nodes
169 vector<Node *> Graph::getSuccNodes(Node *nd) const {
170   nodeMapTy::const_iterator nli = nodes.find(nd);
171   assert(nli != nodes.end() && "Node must be in nodes map");
172   const nodeList &nl = nli->second;
173
174   vector<Node *> lt;
175   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
176     lt.push_back(NI->element);
177
178   return lt;
179 }
180
181 //get the number of outgoing edges
182 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
183   nodeMapTy::const_iterator nli = nodes.find(nd);
184   assert(nli != nodes.end() && "Node must be in nodes map");
185   const nodeList &nl = nli->second;
186
187   int count=0;
188   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
189     count++;
190
191   return count;
192 }
193
194 //get the list of predecessor nodes
195 vector<Node *> Graph::getPredNodes(Node *nd) const{
196   vector<Node *> lt;
197   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
198     Node *lnode=EI->first;
199     const nodeList &nl = getNodeList(lnode);
200
201     const graphListElement *N = findNodeInList(nl, nd);
202     if (N) lt.push_back(lnode);
203   }
204   return lt;
205 }
206
207 //get the number of predecessor nodes
208 int Graph::getNumberOfIncomingEdges(Node *nd) const{
209   int count=0;
210   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
211     Node *lnode=EI->first;
212     const nodeList &nl = getNodeList(lnode);
213     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
214         ++NI)
215       if (*NI->element== *nd)
216         count++;
217   }
218   return count;
219 }
220
221 //get the list of all the vertices in graph
222 vector<Node *> Graph::getAllNodes() const{
223   vector<Node *> lt;
224   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
225     lt.push_back(x->first);
226
227   return lt;
228 }
229
230 //get the list of all the vertices in graph
231 vector<Node *> Graph::getAllNodes(){
232   vector<Node *> lt;
233   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
234     lt.push_back(x->first);
235
236   return lt;
237 }
238
239 //class to compare two nodes in graph
240 //based on their wt: this is used in
241 //finding the maximal spanning tree
242 struct compare_nodes {
243   bool operator()(Node *n1, Node *n2){
244     return n1->getWeight() < n2->getWeight();
245   }
246 };
247
248
249 static void printNode(Node *nd){
250   cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
251 }
252
253 //Get the Maximal spanning tree (also a graph)
254 //of the graph
255 Graph* Graph::getMaxSpanningTree(){
256   //assume connected graph
257  
258   Graph *st=new Graph();//max spanning tree, undirected edges
259   int inf=9999999;//largest key
260   vector<Node *> lt = getAllNodes();
261   
262   //initially put all vertices in vector vt
263   //assign wt(root)=0
264   //wt(others)=infinity
265   //
266   //now:
267   //pull out u: a vertex frm vt of min wt
268   //for all vertices w in vt, 
269   //if wt(w) greater than 
270   //the wt(u->w), then assign
271   //wt(w) to be wt(u->w).
272   //
273   //make parent(u)=w in the spanning tree
274   //keep pulling out vertices from vt till it is empty
275
276   vector<Node *> vt;
277   
278   map<Node*, Node* > parent;
279   map<Node*, int > ed_weight;
280
281   //initialize: wt(root)=0, wt(others)=infinity
282   //parent(root)=NULL, parent(others) not defined (but not null)
283   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
284     Node *thisNode=*LI;
285     if(*thisNode == *getRoot()){
286       thisNode->setWeight(0);
287       parent[thisNode]=NULL;
288       ed_weight[thisNode]=0;
289     }
290     else{ 
291       thisNode->setWeight(inf);
292     }
293     st->addNode(thisNode);//add all nodes to spanning tree
294     //we later need to assign edges in the tree
295     vt.push_back(thisNode); //pushed all nodes in vt
296   }
297
298   //keep pulling out vertex of min wt from vt
299   while(!vt.empty()){
300     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
301     DEBUG(cerr<<"popped wt"<<(u)->getWeight()<<"\n";
302           printNode(u));
303
304     if(parent[u]!=NULL){ //so not root
305       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
306       st->addEdge(edge,ed_weight[u]);
307
308       DEBUG(cerr<<"added:\n";
309             printEdge(edge));
310     }
311
312     //vt.erase(u);
313     
314     //remove u frm vt
315     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
316       if(**VI==*u){
317         vt.erase(VI);
318         break;
319       }
320     }
321     
322     //assign wt(v) to all adjacent vertices v of u
323     //only if v is in vt
324     Graph::nodeList nl=getNodeList(u);
325     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
326       Node *v=NI->element;
327       int weight=-NI->weight;
328       //check if v is in vt
329       bool contains=false;
330       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
331         if(**VI==*v){
332           contains=true;
333           break;
334         }
335       }
336       DEBUG(cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
337             printNode(v);cerr<<"node wt:"<<(*v).weight<<"\n");
338
339       //so if v in in vt, change wt(v) to wt(u->v)
340       //only if wt(u->v)<wt(v)
341       if(contains && weight<v->getWeight()){
342         parent[v]=u;
343         ed_weight[v]=weight;
344         v->setWeight(weight);
345
346         DEBUG(cerr<<v->getWeight()<<":Set weight------\n";
347               printGraph();
348               printEdge(Edge(u,v,weight)));
349       }
350     }
351   }
352   return st;
353 }
354
355 //print the graph (for debugging)   
356 void Graph::printGraph(){
357    vector<Node *> lt=getAllNodes();
358    cerr<<"Graph---------------------\n";
359    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
360      cerr<<((*LI)->getElement())->getName()<<"->";
361      Graph::nodeList nl=getNodeList(*LI);
362      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
363        cerr<<":"<<"("<<(NI->element->getElement())
364          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
365      }
366      cerr<<"--------\n";
367    }
368 }
369
370
371 //get a list of nodes in the graph
372 //in r-topological sorted order
373 //note that we assumed graph to be connected
374 vector<Node *> Graph::reverseTopologicalSort() const{
375   vector <Node *> toReturn;
376   vector<Node *> lt=getAllNodes();
377   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
378     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
379       DFS_Visit(*LI, toReturn);
380   }
381   return toReturn;
382 }
383
384 //a private method for doing DFS traversal of graph
385 //this is used in determining the reverse topological sort 
386 //of the graph
387 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn) const {
388   nd->setWeight(GREY);
389   vector<Node *> lt=getSuccNodes(nd);
390   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
391     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
392       DFS_Visit(*LI, toReturn);
393   }
394   toReturn.push_back(nd);
395 }
396
397 //Ordinarily, the graph is directional
398 //this converts the graph into an 
399 //undirectional graph
400 //This is done by adding an edge
401 //v->u for all existing edges u->v
402 void Graph::makeUnDirectional(){
403   vector<Node* > allNodes=getAllNodes();
404   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
405       ++NI) {
406     nodeList nl=getNodeList(*NI);
407     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
408       Edge ed(NLI->element, *NI, NLI->weight);
409       if(!hasEdgeAndWt(ed)){
410         DEBUG(cerr<<"######doesn't hv\n";
411               printEdge(ed));
412         addEdgeForce(ed);
413       }
414     }
415   }
416 }
417
418 //reverse the sign of weights on edges
419 //this way, max-spanning tree could be obtained
420 //usin min-spanning tree, and vice versa
421 void Graph::reverseWts(){
422   vector<Node *> allNodes=getAllNodes();
423   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
424       ++NI) {
425     nodeList node_list=getNodeList(*NI);
426     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
427         NLI!=NLE; ++NLI)
428       NLI->weight=-NLI->weight;
429   }
430 }
431
432
433 //getting the backedges in a graph
434 //Its a variation of DFS to get the backedges in the graph
435 //We get back edges by associating a time
436 //and a color with each vertex.
437 //The time of a vertex is the time when it was first visited
438 //The color of a vertex is initially WHITE,
439 //Changes to GREY when it is first visited,
440 //and changes to BLACK when ALL its neighbors
441 //have been visited
442 //So we have a back edge when we meet a successor of
443 //a node with smaller time, and GREY color
444 void Graph::getBackEdges(vector<Edge > &be) const{
445   map<Node *, Color > color;
446   map<Node *, int > d;
447   vector<Node *> allNodes=getAllNodes();
448   int time=0;
449   for(vector<Node *>::const_iterator NI=allNodes.begin(), NE=allNodes.end(); 
450       NI!=NE; ++NI){
451     if(color[*NI]!=GREY && color[*NI]!=BLACK)
452       getBackEdgesVisit(*NI, be, color, d, time);
453   }
454 }
455
456 //helper function to get back edges: it is called by 
457 //the "getBackEdges" function above
458 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
459                               map<Node *, Color > &color,
460                               map<Node *, int > &d, int &time) const{
461   color[u]=GREY;
462   time++;
463   d[u]=time;
464
465   vector<graphListElement> succ_list=getNodeList(u);
466   for(vector<graphListElement>::const_iterator vl=succ_list.begin(), 
467         ve=succ_list.end(); vl!=ve; ++vl){
468     Node *v=vl->element;
469   //  for(vector<Node *>::const_iterator v=succ_list.begin(), ve=succ_list.end(); 
470   //  v!=ve; ++v){
471
472     if(color[v]!=GREY && color[v]!=BLACK){
473       getBackEdgesVisit(v, be, color, d, time);
474     }
475
476     //now checking for d and f vals
477     if(color[v]==GREY){
478       //so v is ancestor of u if time of u > time of v
479       if(d[u] >= d[v]){
480         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
481         if (!(*u == *getExit() && *v == *getRoot()))
482           be.push_back(*ed);      // choose the forward edges
483       }
484     }
485   }
486   color[u]=BLACK;//done with visiting the node and its neighbors
487 }
488
489