changes.
[IRC.git] / Robust / src / Analysis / Loops / CSE.java
1 package Analysis.Loops;
2
3 import IR.Flat.*;
4 import IR.TypeUtil;
5 import IR.Operation;
6 import IR.FieldDescriptor;
7 import IR.MethodDescriptor;
8 import IR.TypeDescriptor;
9 import java.util.Map;
10 import java.util.Iterator;
11 import java.util.Hashtable;
12 import java.util.HashSet;
13 import java.util.Set;
14
15 public class CSE {
16   GlobalFieldType gft;
17   TypeUtil typeutil;
18   public CSE(GlobalFieldType gft, TypeUtil typeutil) {
19     this.gft=gft;
20     this.typeutil=typeutil;
21   }
22
23   public void doAnalysis(FlatMethod fm) {
24     Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr=new Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>>();
25     HashSet toprocess=new HashSet();
26     HashSet discovered=new HashSet();
27     toprocess.add(fm);
28     discovered.add(fm);
29     while(!toprocess.isEmpty()) {
30       FlatNode fn=(FlatNode)toprocess.iterator().next();
31       toprocess.remove(fn);
32       for(int i=0; i<fn.numNext(); i++) {
33         FlatNode nnext=fn.getNext(i);
34         if (!discovered.contains(nnext)) {
35           toprocess.add(nnext);
36           discovered.add(nnext);
37         }
38       }
39       Hashtable<Expression, TempDescriptor> tab=computeIntersection(fn, availexpr);
40       
41 //      if(tab.size()>1000){
42 //        System.out.println("Skipping CSE of "+fm.getMethod()+" due to size.");
43 //        return;
44 //      }
45       
46       //Do kills of expression/variable mappings
47       TempDescriptor[] write=fn.writesTemps();
48       for(int i=0; i<write.length; i++) {
49         for(Iterator it=tab.entrySet().iterator(); it.hasNext(); ) {
50           Map.Entry m=(Map.Entry)it.next();
51           TempDescriptor td=(TempDescriptor)m.getValue();
52           if(td.equals(write[i])){
53             it.remove();
54           }        
55         }
56       }
57       
58       
59       
60
61       switch(fn.kind()) {
62       case FKind.FlatAtomicEnterNode:
63       {
64         killexpressions(tab, null, null, true);
65         break;
66       }
67
68       case FKind.FlatCall:
69       {
70         FlatCall fc=(FlatCall) fn;
71         MethodDescriptor md=fc.getMethod();
72         Set<FieldDescriptor> fields=gft.getFieldsAll(md);
73         Set<TypeDescriptor> arrays=gft.getArraysAll(md);
74         killexpressions(tab, fields, arrays, gft.containsAtomicAll(md)||gft.containsBarrierAll(md));
75         break;
76       }
77
78       case FKind.FlatOpNode:
79       {
80         FlatOpNode fon=(FlatOpNode) fn;
81         Expression e=new Expression(fon.getLeft(), fon.getRight(), fon.getOp());
82         tab.put(e, fon.getDest());
83         break;
84       }
85
86       case FKind.FlatSetFieldNode:
87       {
88         FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
89         Set<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
90         fields.add(fsfn.getField());
91         killexpressions(tab, fields, null, false);
92         Expression e=new Expression(fsfn.getDst(), fsfn.getField());
93         tab.put(e, fsfn.getSrc());
94         break;
95       }
96
97       case FKind.FlatFieldNode:
98       {
99         FlatFieldNode ffn=(FlatFieldNode)fn;
100         Expression e=new Expression(ffn.getSrc(), ffn.getField());
101         tab.put(e, ffn.getDst());
102         break;
103       }
104
105       case FKind.FlatSetElementNode:
106       {
107         FlatSetElementNode fsen=(FlatSetElementNode)fn;
108         Expression e=new Expression(fsen.getDst(),fsen.getIndex());
109         tab.put(e, fsen.getSrc());
110         break;
111       }
112
113       case FKind.FlatElementNode:
114       {
115         FlatElementNode fen=(FlatElementNode)fn;
116         Expression e=new Expression(fen.getSrc(),fen.getIndex());
117         tab.put(e, fen.getDst());
118         break;
119       }
120
121       default:
122       }
123
124       if (write.length==1) {
125         TempDescriptor w=write[0];
126         for(Iterator it=tab.entrySet().iterator(); it.hasNext(); ) {
127           Map.Entry m=(Map.Entry)it.next();
128           Expression e=(Expression)m.getKey();
129           if (e.a==w||e.b==w)
130             it.remove();
131         }
132       }
133       if (!availexpr.containsKey(fn)||!availexpr.get(fn).equals(tab)) {
134         availexpr.put(fn, tab);
135         for(int i=0; i<fn.numNext(); i++) {
136           FlatNode nnext=fn.getNext(i);
137           toprocess.add(nnext);
138         }
139       }
140     }
141
142     doOptimize(fm, availexpr);
143   }
144
145   public void doOptimize(FlatMethod fm, Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr) {
146     Hashtable<FlatNode, FlatNode> replacetable=new Hashtable<FlatNode, FlatNode>();
147     for(Iterator<FlatNode> it=fm.getNodeSet().iterator(); it.hasNext(); ) {
148       FlatNode fn=it.next();
149       Hashtable<Expression, TempDescriptor> tab=computeIntersection(fn, availexpr);
150       switch(fn.kind()) {
151       case FKind.FlatOpNode:
152       {
153         FlatOpNode fon=(FlatOpNode) fn;
154         Expression e=new Expression(fon.getLeft(), fon.getRight(),fon.getOp());
155         if (tab.containsKey(e)) {
156           TempDescriptor t=tab.get(e);
157           FlatNode newfon=new FlatOpNode(fon.getDest(),t,null,new Operation(Operation.ASSIGN));
158           replacetable.put(fon,newfon);
159         }
160         break;
161       }
162
163       case FKind.FlatFieldNode:
164       {
165         FlatFieldNode ffn=(FlatFieldNode)fn;
166         Expression e=new Expression(ffn.getSrc(), ffn.getField());
167         if (tab.containsKey(e)) {
168           TempDescriptor t=tab.get(e);
169           FlatNode newfon=new FlatOpNode(ffn.getDst(),t,null,new Operation(Operation.ASSIGN));
170           replacetable.put(ffn,newfon);
171         }
172         break;
173       }
174
175       case FKind.FlatElementNode:
176       {
177         FlatElementNode fen=(FlatElementNode)fn;
178         Expression e=new Expression(fen.getSrc(),fen.getIndex());
179         if (tab.containsKey(e)) {
180           TempDescriptor t=tab.get(e);
181           FlatNode newfon=new FlatOpNode(fen.getDst(),t,null,new Operation(Operation.ASSIGN));
182           replacetable.put(fen,newfon);
183         }
184         break;
185       }
186
187       default:
188       }
189     }
190     for(Iterator<FlatNode> it=replacetable.keySet().iterator(); it.hasNext(); ) {
191       FlatNode fn=it.next();
192       FlatNode newfn=replacetable.get(fn);
193       fn.replace(newfn);
194     }
195   }
196
197   public Hashtable<Expression, TempDescriptor> computeIntersection(FlatNode fn, Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr) {
198     Hashtable<Expression, TempDescriptor> tab=new Hashtable<Expression, TempDescriptor>();
199     boolean first=true;
200
201     //compute intersection
202     for(int i=0; i<fn.numPrev(); i++) {
203       FlatNode prev=fn.getPrev(i);
204       if (first) {
205         if (availexpr.containsKey(prev)) {
206           tab.putAll(availexpr.get(prev));
207           first=false;
208         }
209       } else {
210         if (availexpr.containsKey(prev)) {
211           Hashtable<Expression, TempDescriptor> table=availexpr.get(prev);
212           for(Iterator mapit=tab.entrySet().iterator(); mapit.hasNext(); ) {
213             Object entry=mapit.next();
214             if (!table.contains(entry))
215               mapit.remove();
216           }
217         }
218       }
219     }
220     return tab;
221   }
222
223   public void killexpressions(Hashtable<Expression, TempDescriptor> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays, boolean killall) {
224     for(Iterator it=tab.entrySet().iterator(); it.hasNext(); ) {
225       Map.Entry m=(Map.Entry)it.next();
226       Expression e=(Expression)m.getKey();
227       if (killall&&(e.f!=null||e.a!=null))
228         it.remove();
229       else if (e.f!=null&&fields!=null&&fields.contains(e.f))
230         it.remove();
231       else if ((e.a!=null)&&(arrays!=null)) {
232         for(Iterator<TypeDescriptor> arit=arrays.iterator(); arit.hasNext(); ) {
233           TypeDescriptor artd=arit.next();
234           if (typeutil.isSuperorType(artd,e.a.getType())||
235               typeutil.isSuperorType(e.a.getType(),artd)) {
236             it.remove();
237             break;
238           }
239         }
240       }
241     }
242   }
243 }
244
245 class Expression {
246   Operation op;
247   TempDescriptor a;
248   TempDescriptor b;
249   FieldDescriptor f;
250   Expression(TempDescriptor a, TempDescriptor b, Operation op) {
251     this.a=a;
252     this.b=b;
253     this.op=op;
254   }
255   Expression(TempDescriptor a, FieldDescriptor f) {
256     this.a=a;
257     this.f=f;
258   }
259   Expression(TempDescriptor a, TempDescriptor index) {
260     this.a=a;
261     this.b=index;
262   }
263   public int hashCode() {
264     int h=0;
265     h^=a.hashCode();
266     if (op!=null)
267       h^=op.getOp();
268     if (b!=null)
269       h^=b.hashCode();
270     if (f!=null)
271       h^=f.hashCode();
272     return h;
273   }
274   public boolean equals(Object o) {
275     Expression e=(Expression)o;
276     if (a!=e.a||f!=e.f||b!=e.b)
277       return false;
278     if (op!=null)
279       return op.getOp()==e.op.getOp();
280     return true;
281   }
282 }