Updated CallGraph to keep an inverse map (callee->caller set)
[IRC.git] / Robust / src / Analysis / CallGraph / CallGraph.java
1 package Analysis.CallGraph;
2 import IR.State;
3 import IR.Flat.FlatMethod;
4 import IR.Flat.FlatNode;
5 import IR.Flat.FlatCall;
6 import IR.Flat.FKind;
7 import IR.Descriptor;
8 import IR.ClassDescriptor;
9 import IR.MethodDescriptor;
10 import IR.TaskDescriptor;
11 import IR.TypeDescriptor;
12 import java.util.*;
13 import java.io.*;
14
15 public class CallGraph {
16     private State state;
17
18     // MethodDescriptor maps to HashSet<MethodDescriptor>
19     private Hashtable mapVirtual2ImplementationSet;
20
21     // MethodDescriptor or TaskDescriptor maps to HashSet<MethodDescriptor>
22     private Hashtable mapCaller2CalleeSet;
23
24     // MethodDescriptor maps to HashSet<MethodDescriptor or TaskDescriptor>
25     private Hashtable mapCallee2CallerSet;
26
27     public CallGraph(State state) {
28         this.state=state;
29         mapVirtual2ImplementationSet = new Hashtable();
30         mapCaller2CalleeSet          = new Hashtable();
31         mapCallee2CallerSet          = new Hashtable();
32         buildVirtualMap();
33         buildGraph();
34     }
35     
36     // build a mapping of virtual methods to all
37     // possible implementations of that method
38     private void buildVirtualMap() {
39         //Iterator through classes
40         Iterator it=state.getClassSymbolTable().getDescriptorsIterator();
41         while(it.hasNext()) {
42             ClassDescriptor cn=(ClassDescriptor)it.next();
43             Iterator methodit=cn.getMethods();
44             //Iterator through methods
45             while(methodit.hasNext()) {
46                 MethodDescriptor md=(MethodDescriptor)methodit.next();
47                 if (md.isStatic()||md.getReturnType()==null)
48                     continue;
49                 ClassDescriptor superdesc=cn.getSuperDesc();
50                 if (superdesc!=null) {
51                     Set possiblematches=superdesc.getMethodTable().getSet(md.getSymbol());
52                     boolean foundmatch=false;
53                     for(Iterator matchit=possiblematches.iterator();matchit.hasNext();) {
54                         MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
55                         if (md.matches(matchmd)) {
56                             if (!mapVirtual2ImplementationSet.containsKey(matchmd))
57                                 mapVirtual2ImplementationSet.put(matchmd,new HashSet());
58                             ((HashSet)mapVirtual2ImplementationSet.get(matchmd)).add(md);
59                             break;
60                         }
61                     }
62                 }
63             }
64         }
65     }
66
67     public Set getMethods(MethodDescriptor md, TypeDescriptor type) {
68         return getMethods(md);
69     }
70
71     /** Given a call to MethodDescriptor, lists the methods which
72         could actually be called due to virtual dispatch. */
73     public Set getMethods(MethodDescriptor md) {
74         HashSet ns=new HashSet();
75         ns.add(md);
76         Set s=(Set)mapVirtual2ImplementationSet.get(md);
77         if (s!=null)
78             for(Iterator it=s.iterator();it.hasNext();) {
79                 MethodDescriptor md2=(MethodDescriptor)it.next();
80                 ns.addAll(getMethods(md2));
81             }
82         return ns;
83     }
84
85     /** Given a call to MethodDescriptor, lists the methods which
86         could actually be call by that method. */
87     public Set getMethodCalls(MethodDescriptor md) {
88         HashSet ns=new HashSet();
89         ns.add(md);
90         Set s=(Set)mapCaller2CalleeSet.get(md);
91         if (s!=null)
92             for(Iterator it=s.iterator();it.hasNext();) {
93                 MethodDescriptor md2=(MethodDescriptor)it.next();
94                 ns.addAll(getMethodCalls(md2));
95             }
96         return ns;
97     }
98
99     private void buildGraph() { 
100         Iterator it=state.getClassSymbolTable().getDescriptorsIterator();
101         while(it.hasNext()) {
102             ClassDescriptor cn=(ClassDescriptor)it.next();
103             Iterator methodit=cn.getMethods();
104             //Iterator through methods
105             while(methodit.hasNext()) {
106                 MethodDescriptor md=(MethodDescriptor)methodit.next();
107                 analyzeMethod( (Object)md, state.getMethodFlat(md) );
108             }
109         }
110         it=state.getTaskSymbolTable().getDescriptorsIterator();
111         while(it.hasNext()) {
112             TaskDescriptor td=(TaskDescriptor)it.next();
113             analyzeMethod( (Object)td, state.getMethodFlat(td) );
114         }
115     }
116
117     private void analyzeMethod(Object caller, FlatMethod fm) {
118         HashSet toexplore=new HashSet();
119         toexplore.add(fm);
120         HashSet explored=new HashSet();
121         //look at all the nodes in the flat representation
122         while(!toexplore.isEmpty()) {
123             FlatNode fn=(FlatNode)(toexplore.iterator()).next();
124             toexplore.remove(fn);
125             explored.add(fn);
126             for(int i=0;i<fn.numNext();i++) {
127                 FlatNode fnnext=fn.getNext(i);
128                 if (!explored.contains(fnnext))
129                     toexplore.add(fnnext);
130             }
131             if (fn.kind()==FKind.FlatCall) {
132                 FlatCall fc=(FlatCall)fn;
133                 MethodDescriptor calledmethod=fc.getMethod();
134                 Set methodsthatcouldbecalled=fc.getThis()==null?getMethods(calledmethod):
135                     getMethods(calledmethod, fc.getThis().getType());
136
137                 // add caller -> callee maps
138                 if( !mapCaller2CalleeSet.containsKey( caller ) ) {
139                     mapCaller2CalleeSet.put( caller, new HashSet() );
140                 }
141                 ((HashSet)mapCaller2CalleeSet.get( caller )).addAll( methodsthatcouldbecalled );
142
143                 // add callee -> caller maps
144                 Iterator calleeItr = methodsthatcouldbecalled.iterator();
145                 while( calleeItr.hasNext() ) {
146                     MethodDescriptor callee = (MethodDescriptor) calleeItr.next();
147                     if( !mapCallee2CallerSet.containsKey( callee ) ) {
148                         mapCallee2CallerSet.put( callee, new HashSet() );
149                     }
150                     ((HashSet)mapCallee2CallerSet.get( callee )).add( caller );
151                 }
152             }
153         }
154     }
155
156     public void writeToDot( String graphName )  throws java.io.IOException {
157         // each task or method only needs to be labeled once
158         // in a dot file
159         HashSet labeledInDot = new HashSet();
160
161         // write out the call graph using the callees mapping
162         BufferedWriter bw = new BufferedWriter( new FileWriter( graphName+"byCallees.dot" ) );
163         bw.write( "digraph "+graphName+"byCallees {\n" );
164         Iterator mapItr = mapCallee2CallerSet.entrySet().iterator();
165         while( mapItr.hasNext() ) {
166             Map.Entry        me        = (Map.Entry)        mapItr.next();
167             MethodDescriptor callee    = (MethodDescriptor) me.getKey();
168             HashSet          callerSet = (HashSet)          me.getValue();
169
170             if( !labeledInDot.contains( callee ) ) {
171                 labeledInDot.add( callee );
172                 bw.write( "  " + callee.getNum() + "[label=\"" + callee + "\"];\n" );
173             }
174
175             Iterator callerItr = callerSet.iterator();
176             while( callerItr.hasNext() ) {
177                 Descriptor caller = (Descriptor) callerItr.next();
178
179                 if( !labeledInDot.contains( caller ) ) {
180                     labeledInDot.add( caller );
181                     bw.write( "  " + caller.getNum() + "[label=\"" + caller + "\"];\n" );
182                 }
183
184                 bw.write( "  " + callee.getNum() + "->" + caller.getNum() + ";\n" );
185             }
186         }
187         bw.write( "}\n" );
188         bw.close();
189
190         // write out the call graph (should be equivalent) by
191         // using the callers mapping
192         labeledInDot = new HashSet();
193         bw = new BufferedWriter( new FileWriter( graphName+"byCallers.dot" ) );
194         bw.write( "digraph "+graphName+"byCallers {\n" );
195         mapItr = mapCaller2CalleeSet.entrySet().iterator();
196         while( mapItr.hasNext() ) {
197             Map.Entry  me        = (Map.Entry)  mapItr.next();
198             Descriptor caller    = (Descriptor) me.getKey();
199             HashSet    calleeSet = (HashSet)    me.getValue();
200
201             if( !labeledInDot.contains( caller ) ) {
202                 labeledInDot.add( caller );
203                 bw.write( "  " + caller.getNum() + "[label=\"" + caller + "\"];\n" );
204             }
205
206             Iterator calleeItr = calleeSet.iterator();
207             while( calleeItr.hasNext() ) {
208                 MethodDescriptor callee = (MethodDescriptor) calleeItr.next();
209
210                 if( !labeledInDot.contains( callee ) ) {
211                     labeledInDot.add( callee );
212                     bw.write( "  " + callee.getNum() + "[label=\"" + callee + "\"];\n" );
213                 }
214
215                 bw.write( "  " + callee.getNum() + "->" + caller.getNum() + ";\n" );
216             }
217         }
218         bw.write( "}\n" );
219         bw.close();
220     }
221 }