nasty bugs...finally fixed
[IRC.git] / Robust / src / Analysis / Loops / localCSE.java
1 package Analysis.Loops;
2
3 import IR.MethodDescriptor;
4 import IR.TypeDescriptor;
5 import IR.TypeUtil;
6 import IR.Operation;
7 import IR.Flat.*;
8 import IR.FieldDescriptor;
9 import java.util.Set;
10 import java.util.HashSet;
11 import java.util.Hashtable;
12 import java.util.Iterator;
13
14 public class localCSE {
15   GlobalFieldType gft;
16   TypeUtil typeutil;
17   public localCSE(GlobalFieldType gft, TypeUtil typeutil) {
18     this.gft=gft;
19     this.typeutil=typeutil;
20   }
21   int index;
22
23   public Group getGroup(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
24     LocalExpression e=new LocalExpression(t);
25     return getGroup(tab, e);
26   }
27   public Group getGroup(Hashtable<LocalExpression, Group> tab, LocalExpression e) {
28     if (tab.containsKey(e))
29       return tab.get(e);
30     else {
31       Group g=new Group(index++);
32       g.set.add(e);
33       tab.put(e,g);
34       return g;
35     }
36   }
37   public TempDescriptor getTemp(Group g) {
38     for(Iterator it=g.set.iterator();it.hasNext();) {
39       LocalExpression e=(LocalExpression)it.next();
40       if (e.t!=null)
41         return e.t;
42     }
43     return null;
44   }
45
46   public void doAnalysis(FlatMethod fm) {
47     Set nodes=fm.getNodeSet();
48     HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
49     for(Iterator it=nodes.iterator();it.hasNext();) {
50       FlatNode fn=(FlatNode)it.next();
51       if (fn.numPrev()>1)
52         toanalyze.add(fn);
53     }
54     for(Iterator<FlatNode> it=toanalyze.iterator();it.hasNext();) {
55       FlatNode fn=it.next();
56       Hashtable<LocalExpression, Group> table=new Hashtable<LocalExpression,Group>();
57       do {
58         index=0;
59         switch(fn.kind()) {
60         case FKind.FlatOpNode: {
61           FlatOpNode fon=(FlatOpNode)fn;
62           Group left=getGroup(table, fon.getLeft());
63           Group right=getGroup(table, fon.getRight());
64           LocalExpression dst=new LocalExpression(fon.getDest());
65           if (fon.getOp().getOp()==Operation.ASSIGN) {
66             left.set.add(dst);
67             kill(table, fon.getDest());
68             table.put(dst, left);
69           } else {
70             LocalExpression e=new LocalExpression(left, right, fon.getOp());
71             Group g=getGroup(table,e);
72             TempDescriptor td=getTemp(g);
73             if (td!=null) {
74               FlatNode nfon=new FlatOpNode(fon.getDest(),td,null,new Operation(Operation.ASSIGN));
75               fn.replace(nfon);
76             }
77             g.set.add(dst);
78             kill(table, fon.getDest());
79             table.put(dst,g);
80           }
81           break;
82         }
83         case FKind.FlatLiteralNode: {
84           FlatLiteralNode fln=(FlatLiteralNode)fn;
85           LocalExpression e=new LocalExpression(fln.getValue());
86           Group src=getGroup(table, e);
87           LocalExpression dst=new LocalExpression(fln.getDst());
88           src.set.add(dst);
89           kill(table, fln.getDst());
90           table.put(dst, src);
91           break;
92         }
93         case FKind.FlatFieldNode: {
94           FlatFieldNode ffn=(FlatFieldNode) fn;
95           Group src=getGroup(table, ffn.getSrc());
96           LocalExpression e=new LocalExpression(src, ffn.getField());
97           Group srcf=getGroup(table, e);
98           LocalExpression dst=new LocalExpression(ffn.getDst());
99           TempDescriptor td=getTemp(srcf);
100           if (td!=null) {
101             FlatOpNode fon=new FlatOpNode(ffn.getDst(),td,null,new Operation(Operation.ASSIGN));
102             fn.replace(fon);
103           }
104           srcf.set.add(dst);
105           kill(table, ffn.getDst());
106           table.put(dst, srcf);
107           break;
108         }
109         case FKind.FlatElementNode: {
110           FlatElementNode fen=(FlatElementNode) fn;
111           Group src=getGroup(table, fen.getSrc());
112           Group index=getGroup(table, fen.getIndex());
113           LocalExpression e=new LocalExpression(src, fen.getSrc().getType(), index);
114           Group srcf=getGroup(table, e);
115           LocalExpression dst=new LocalExpression(fen.getDst());
116           TempDescriptor td=getTemp(srcf);
117           if (td!=null) {
118             FlatOpNode fon=new FlatOpNode(fen.getDst(),td,null,new Operation(Operation.ASSIGN));
119             fn.replace(fon);
120           }
121           srcf.set.add(dst);
122           kill(table, fen.getDst());
123           table.put(dst, srcf);
124           break;
125         }
126         case FKind.FlatSetFieldNode: {
127           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
128           Group dst=getGroup(table, fsfn.getDst());
129           LocalExpression e=new LocalExpression(dst, fsfn.getField());
130           Group dstf=getGroup(table, e);
131           LocalExpression src=new LocalExpression(fsfn.getSrc());
132           dstf.set.add(src);
133           HashSet<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
134           fields.add(fsfn.getField());
135           kill(table, fields, null, false, false);
136           table.put(src, dstf);
137           break;
138         }
139         case FKind.FlatSetElementNode: {
140           FlatSetElementNode fsen=(FlatSetElementNode)fn;
141           Group dst=getGroup(table, fsen.getDst());
142           Group index=getGroup(table, fsen.getIndex());
143           LocalExpression e=new LocalExpression(dst, fsen.getDst().getType(), index);
144           Group dstf=getGroup(table, e);
145           LocalExpression src=new LocalExpression(fsen.getSrc());
146           dstf.set.add(src);
147           HashSet<TypeDescriptor> arrays=new HashSet<TypeDescriptor>();
148           arrays.add(fsen.getDst().getType());
149           kill(table, null, arrays, false, false);
150           table.put(src, dstf);
151           break;
152         }
153         case FKind.FlatCall:{
154           //do side effects
155           FlatCall fc=(FlatCall)fn;
156           MethodDescriptor md=fc.getMethod();
157           Set<FieldDescriptor> fields=gft.getFieldsAll(md);
158           Set<TypeDescriptor> arrays=gft.getArraysAll(md);
159           kill(table, fields, arrays, gft.containsAtomicAll(md), gft.containsBarrierAll(md));
160         }
161         default: {
162           TempDescriptor[] writes=fn.writesTemps();
163           for(int i=0;i<writes.length;i++) {
164             kill(table,writes[i]);
165           }
166         }
167         }
168       } while(fn.numPrev()==1);
169     }
170   }
171   public void kill(Hashtable<LocalExpression, Group> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays, boolean isAtomic, boolean isBarrier) {
172     Set<LocalExpression> eset=tab.keySet();
173     for(Iterator<LocalExpression> it=eset.iterator();it.hasNext();) {
174       LocalExpression e=it.next();
175       if (isBarrier) {
176         //make Barriers kill everything
177         it.remove();
178       } else if (isAtomic&&(e.td!=null||e.f!=null)) {
179         Group g=tab.get(e);
180         g.set.remove(e);
181         it.remove();
182       } else if (e.td!=null) {
183         //have array
184         TypeDescriptor artd=e.td;
185         for(Iterator<TypeDescriptor> arit=arrays.iterator();arit.hasNext();) {
186           TypeDescriptor td=arit.next();
187           if (typeutil.isSuperorType(artd,td)||
188               typeutil.isSuperorType(td,artd)) {
189             Group g=tab.get(e);
190             g.set.remove(e);
191             it.remove();
192             break;
193           }
194         }
195       } else if (e.f!=null) {
196         if (fields.contains(e.f)) {
197           Group g=tab.get(e);
198           g.set.remove(e);
199           it.remove();
200         }
201       }
202     }
203   }
204   public void kill(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
205     LocalExpression e=new LocalExpression(t);
206     Group g=tab.get(e);
207     if (g!=null) {
208       tab.remove(e);
209       g.set.remove(e);
210     }
211   }
212 }
213
214 class Group {
215   HashSet set;
216   int i;
217   Group(int i) {
218     set=new HashSet();
219     this.i=i;
220   }
221   public int hashCode() {
222     return i;
223   }
224   public boolean equals(Object o) {
225     Group g=(Group)o;
226     return i==g.i;
227   }
228 }
229
230 class LocalExpression {
231   Operation op;
232   Object obj;
233   Group a;
234   Group b;
235   TempDescriptor t;
236   FieldDescriptor f;
237   TypeDescriptor td;
238   LocalExpression(TempDescriptor t) {
239     this.t=t;
240   }
241   LocalExpression(Object o) {
242     this.obj=o;
243   }
244   LocalExpression(Group a, Group b, Operation op) {
245     this.a=a;
246     this.b=b;
247     this.op=op;
248   }
249   LocalExpression(Group a, FieldDescriptor f) {
250     this.a=a;
251     this.f=f;
252   }
253   LocalExpression(Group a, TypeDescriptor td, Group index) {
254     this.a=a;
255     this.td=td;
256     this.b=index;
257   }
258   public int hashCode() {
259     int h=0;
260     if (td!=null)
261       h^=td.hashCode();
262     if (t!=null)
263       h^=t.hashCode();
264     if (a!=null)
265       h^=a.hashCode();
266     if (op!=null)
267       h^=op.getOp();
268     if (b!=null)
269       h^=b.hashCode();
270     if (f!=null)
271       h^=f.hashCode();
272     if (obj!=null)
273       h^=obj.hashCode();
274     return h;
275   }
276   public static boolean equiv(Object a, Object b) {
277     if (a!=null)
278       return a.equals(b);
279     else
280       return b==null;
281   }
282
283   public boolean equals(Object o) {
284     LocalExpression e=(LocalExpression)o;
285     if (!(equiv(a,e.a)&&equiv(f,e.f)&&equiv(b,e.b)&&
286           equiv(td,e.td)&&equiv(this.obj,e.obj)))
287       return false;
288     if (op!=null)
289       return op.getOp()==e.op.getOp();
290     else if (e.op!=null)
291       return false;
292     return true;
293   }
294 }