changes.
[IRC.git] / Robust / src / Analysis / Loops / localCSE.java
1 package Analysis.Loops;
2
3 import IR.MethodDescriptor;
4 import IR.TypeDescriptor;
5 import IR.TypeUtil;
6 import IR.Operation;
7 import IR.Flat.*;
8 import IR.FieldDescriptor;
9 import java.util.Set;
10 import java.util.HashSet;
11 import java.util.Hashtable;
12 import java.util.Iterator;
13
14 public class localCSE {
15   GlobalFieldType gft;
16   TypeUtil typeutil;
17   public localCSE(GlobalFieldType gft, TypeUtil typeutil) {
18     this.gft=gft;
19     this.typeutil=typeutil;
20   }
21   int index;
22
23   public Group getGroup(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
24     LocalExpression e=new LocalExpression(t);
25     return getGroup(tab, e);
26   }
27   public Group getGroup(Hashtable<LocalExpression, Group> tab, LocalExpression e) {
28     if (tab.containsKey(e))
29       return tab.get(e);
30     else {
31       Group g=new Group(index++);
32       g.set.add(e);
33       tab.put(e,g);
34       return g;
35     }
36   }
37   public TempDescriptor getTemp(Group g) {
38     for(Iterator it=g.set.iterator(); it.hasNext(); ) {
39       LocalExpression e=(LocalExpression)it.next();
40       if (e.t!=null)
41         return e.t;
42     }
43     return null;
44   }
45
46   public void doAnalysis(FlatMethod fm) {
47     Set nodes=fm.getNodeSet();
48     HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
49     for(Iterator it=nodes.iterator(); it.hasNext(); ) {
50       FlatNode fn=(FlatNode)it.next();
51       if (fn.numPrev()>1)
52         toanalyze.add(fn);
53     }
54     for(Iterator<FlatNode> it=toanalyze.iterator(); it.hasNext(); ) {
55       FlatNode fn=it.next();
56       Hashtable<LocalExpression, Group> table=new Hashtable<LocalExpression,Group>();
57       do {
58         index=0;
59         switch(fn.kind()) {
60         case FKind.FlatOpNode: {
61           FlatOpNode fon=(FlatOpNode)fn;
62           Group left=getGroup(table, fon.getLeft());
63           Group right=getGroup(table, fon.getRight());
64           LocalExpression dst=new LocalExpression(fon.getDest());
65           if (fon.getOp().getOp()==Operation.ASSIGN) {
66             left.set.add(dst);
67             kill(table, fon.getDest());
68             table.put(dst, left);
69           } else {
70             LocalExpression e=new LocalExpression(left, right, fon.getOp());
71             Group g=getGroup(table,e);
72             TempDescriptor td=getTemp(g);
73             if (td!=null) {
74               FlatNode nfon=new FlatOpNode(fon.getDest(),td,null,new Operation(Operation.ASSIGN));
75               fn.replace(nfon);
76             }
77             g.set.add(dst);
78             kill(table, fon.getDest());
79             table.put(dst,g);
80           }
81           break;
82         }
83
84         case FKind.FlatLiteralNode: {
85           FlatLiteralNode fln=(FlatLiteralNode)fn;
86           LocalExpression e=new LocalExpression(fln.getValue());
87           Group src=getGroup(table, e);
88           LocalExpression dst=new LocalExpression(fln.getDst());
89           src.set.add(dst);
90           kill(table, fln.getDst());
91           table.put(dst, src);
92           break;
93         }
94
95         case FKind.FlatFieldNode: {
96           FlatFieldNode ffn=(FlatFieldNode) fn;
97           Group src=getGroup(table, ffn.getSrc());
98           LocalExpression e=new LocalExpression(src, ffn.getField());
99           Group srcf=getGroup(table, e);
100           LocalExpression dst=new LocalExpression(ffn.getDst());
101           TempDescriptor td=getTemp(srcf);
102           if (td!=null) {
103             FlatOpNode fon=new FlatOpNode(ffn.getDst(),td,null,new Operation(Operation.ASSIGN));
104             fn.replace(fon);
105           }
106           srcf.set.add(dst);
107           kill(table, ffn.getDst());
108           table.put(dst, srcf);
109           break;
110         }
111
112         case FKind.FlatElementNode: {
113           FlatElementNode fen=(FlatElementNode) fn;
114           Group src=getGroup(table, fen.getSrc());
115           Group index=getGroup(table, fen.getIndex());
116           LocalExpression e=new LocalExpression(src, fen.getSrc().getType(), index);
117           Group srcf=getGroup(table, e);
118           LocalExpression dst=new LocalExpression(fen.getDst());
119           TempDescriptor td=getTemp(srcf);
120           if (td!=null) {
121             FlatOpNode fon=new FlatOpNode(fen.getDst(),td,null,new Operation(Operation.ASSIGN));
122             fn.replace(fon);
123           }
124           srcf.set.add(dst);
125           kill(table, fen.getDst());
126           table.put(dst, srcf);
127           break;
128         }
129
130         case FKind.FlatSetFieldNode: {
131           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
132           Group dst=getGroup(table, fsfn.getDst());
133           LocalExpression e=new LocalExpression(dst, fsfn.getField());
134           Group dstf=getGroup(table, e);
135           LocalExpression src=new LocalExpression(fsfn.getSrc());
136           dstf.set.add(src);
137           HashSet<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
138           fields.add(fsfn.getField());
139           kill(table, fields, null, false, false);
140           table.put(src, dstf);
141           break;
142         }
143
144         case FKind.FlatSetElementNode: {
145           FlatSetElementNode fsen=(FlatSetElementNode)fn;
146           Group dst=getGroup(table, fsen.getDst());
147           Group index=getGroup(table, fsen.getIndex());
148           LocalExpression e=new LocalExpression(dst, fsen.getDst().getType(), index);
149           Group dstf=getGroup(table, e);
150           LocalExpression src=new LocalExpression(fsen.getSrc());
151           dstf.set.add(src);
152           HashSet<TypeDescriptor> arrays=new HashSet<TypeDescriptor>();
153           arrays.add(fsen.getDst().getType());
154           kill(table, null, arrays, false, false);
155           table.put(src, dstf);
156           break;
157         }
158
159         case FKind.FlatCall: {
160           //do side effects
161           FlatCall fc=(FlatCall)fn;
162           MethodDescriptor md=fc.getMethod();
163           Set<FieldDescriptor> fields=gft.getFieldsAll(md);
164           Set<TypeDescriptor> arrays=gft.getArraysAll(md);
165           kill(table, fields, arrays, gft.containsAtomicAll(md), gft.containsBarrierAll(md));
166         }
167
168         default: {
169           TempDescriptor[] writes=fn.writesTemps();
170           for(int i=0; i<writes.length; i++) {
171             kill(table,writes[i]);
172           }
173         }
174         }
175       } while(fn.numPrev()==1);
176     }
177   }
178   public void kill(Hashtable<LocalExpression, Group> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays, boolean isAtomic, boolean isBarrier) {
179     Set<LocalExpression> eset=tab.keySet();
180     for(Iterator<LocalExpression> it=eset.iterator(); it.hasNext(); ) {
181       LocalExpression e=it.next();
182       if (isBarrier) {
183         //make Barriers kill everything
184         it.remove();
185       } else if (isAtomic&&(e.td!=null||e.f!=null)) {
186         Group g=tab.get(e);
187         g.set.remove(e);
188         it.remove();
189       } else if (e.td!=null) {
190         //have array
191         TypeDescriptor artd=e.td;
192         for(Iterator<TypeDescriptor> arit=arrays.iterator(); arit.hasNext(); ) {
193           TypeDescriptor td=arit.next();
194           if (typeutil.isSuperorType(artd,td)||
195               typeutil.isSuperorType(td,artd)) {
196             Group g=tab.get(e);
197             g.set.remove(e);
198             it.remove();
199             break;
200           }
201         }
202       } else if (e.f!=null) {
203         if (fields.contains(e.f)) {
204           Group g=tab.get(e);
205           g.set.remove(e);
206           it.remove();
207         }
208       }
209     }
210   }
211   public void kill(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
212     LocalExpression e=new LocalExpression(t);
213     Group g=tab.get(e);
214     if (g!=null) {
215       tab.remove(e);
216       g.set.remove(e);
217     }
218   }
219 }
220
221 class Group {
222   HashSet set;
223   int i;
224   Group(int i) {
225     set=new HashSet();
226     this.i=i;
227   }
228   public int hashCode() {
229     return i;
230   }
231   public boolean equals(Object o) {
232     Group g=(Group)o;
233     return i==g.i;
234   }
235 }
236
237 class LocalExpression {
238   Operation op;
239   Object obj;
240   Group a;
241   Group b;
242   TempDescriptor t;
243   FieldDescriptor f;
244   TypeDescriptor td;
245   LocalExpression(TempDescriptor t) {
246     this.t=t;
247   }
248   LocalExpression(Object o) {
249     this.obj=o;
250   }
251   LocalExpression(Group a, Group b, Operation op) {
252     this.a=a;
253     this.b=b;
254     this.op=op;
255   }
256   LocalExpression(Group a, FieldDescriptor f) {
257     this.a=a;
258     this.f=f;
259   }
260   LocalExpression(Group a, TypeDescriptor td, Group index) {
261     this.a=a;
262     this.td=td;
263     this.b=index;
264   }
265   public int hashCode() {
266     int h=0;
267     if (td!=null)
268       h^=td.hashCode();
269     if (t!=null)
270       h^=t.hashCode();
271     if (a!=null)
272       h^=a.hashCode();
273     if (op!=null)
274       h^=op.getOp();
275     if (b!=null)
276       h^=b.hashCode();
277     if (f!=null)
278       h^=f.hashCode();
279     if (obj!=null)
280       h^=obj.hashCode();
281     return h;
282   }
283   public static boolean equiv(Object a, Object b) {
284     if (a!=null)
285       return a.equals(b);
286     else
287       return b==null;
288   }
289
290   public boolean equals(Object o) {
291     LocalExpression e=(LocalExpression)o;
292     if (!(equiv(a,e.a)&&equiv(f,e.f)&&equiv(b,e.b)&&
293           equiv(td,e.td)&&equiv(this.obj,e.obj)))
294       return false;
295     if (op!=null)
296       return op.getOp()==e.op.getOp();
297     else if (e.op!=null)
298       return false;
299     return true;
300   }
301 }