nasty bugs...finally fixed
[IRC.git] / Robust / src / Analysis / Loops / LoopOptimize.java
1 package Analysis.Loops;
2
3 import IR.Flat.*;
4 import IR.TypeUtil;
5 import IR.MethodDescriptor;
6 import IR.Operation;
7 import java.util.HashSet;
8 import java.util.Set;
9 import java.util.Vector;
10 import java.util.Iterator;
11 import java.util.Hashtable;
12
13 public class LoopOptimize {
14   LoopInvariant loopinv;
15   public LoopOptimize(GlobalFieldType gft, TypeUtil typeutil) {
16     loopinv=new LoopInvariant(typeutil,gft);
17   }
18   Hashtable<FlatNode, FlatNode> ntoomap;
19   Hashtable<FlatNode, FlatNode> clonemap;
20   Hashtable<FlatNode, FlatNode> map;
21
22   public void optimize(FlatMethod fm) {
23     loopinv.analyze(fm);
24     ntoomap=new Hashtable<FlatNode, FlatNode>();
25     map=new Hashtable<FlatNode, FlatNode>();
26     clonemap=new Hashtable<FlatNode, FlatNode>();
27     dooptimize(fm);
28   } 
29
30   private FlatNode ntooremap(FlatNode fn) {
31     while(ntoomap.containsKey(fn)) {
32       fn=ntoomap.get(fn);
33     }
34     return fn;
35   }
36
37   private FlatNode otonremap(FlatNode fn) {
38     while(map.containsKey(fn)) {
39       fn=map.get(fn);
40     }
41     return fn;
42   }
43
44   private void dooptimize(FlatMethod fm) {
45     Loops root=loopinv.root;
46     recurse(fm, root);
47   }
48   private void recurse(FlatMethod fm, Loops parent) {
49     for(Iterator lpit=parent.nestedLoops().iterator();lpit.hasNext();) {
50       Loops child=(Loops)lpit.next();
51       processLoop(fm, child);
52       recurse(fm, child);
53     }
54   }
55   public void processLoop(FlatMethod fm, Loops l) {
56     Set entrances=l.loopEntrances();
57     assert entrances.size()==1;
58     FlatNode entrance=(FlatNode)entrances.iterator().next();
59     if (loopinv.tounroll.contains(entrance)) {
60       unrollLoop(l, fm);
61     } else {
62       hoistOps(l);
63     }
64   }
65   public void hoistOps(Loops l) {
66     Set entrances=l.loopEntrances();
67     assert entrances.size()==1;
68     FlatNode entrance=(FlatNode)entrances.iterator().next();
69     Vector<FlatNode> tohoist=loopinv.table.get(entrance);
70     Set lelements=l.loopIncElements();
71     TempMap t=new TempMap();
72     TempMap tnone=new TempMap();
73     FlatNode first=null;
74     FlatNode last=null;
75     if (tohoist.size()==0)
76       return;
77
78     for(int i=0;i<tohoist.size();i++) {
79       FlatNode fn=tohoist.elementAt(i);
80       TempDescriptor[] writes=fn.writesTemps();
81
82       //deal with the possiblity we already hoisted this node
83       if (clonemap.containsKey(fn)) {
84         FlatNode fnnew=clonemap.get(fn);
85         TempDescriptor writenew[]=fnnew.writesTemps();
86         t.addPair(writes[0],writenew[0]);
87         if (fn==entrance)
88           entrance=map.get(fn);
89         continue;
90       }
91
92       //build hoisted version
93       FlatNode fnnew=fn.clone(tnone);
94       fnnew.rewriteUse(t);
95
96       for(int j=0;j<writes.length;j++) {
97         if (writes[j]!=null) {
98           TempDescriptor cp=writes[j].createNew("a");
99           t.addPair(writes[j],cp);
100         }
101       }
102       fnnew.rewriteDef(t);
103
104       //store mapping
105       clonemap.put(fn, fnnew);
106
107       //add hoisted version to chain
108       if (first==null)
109         first=fnnew;
110       else
111         last.addNext(fnnew);
112       last=fnnew;
113
114       /* Splice out old node */
115       if (writes.length==1) {
116         FlatOpNode fon=new FlatOpNode(writes[0], t.tempMap(writes[0]), null, new Operation(Operation.ASSIGN));
117         fn.replace(fon);
118         ntoomap.put(fon, fn);
119         map.put(fn, fon);
120         if (fn==entrance)
121           entrance=fon;
122       } else if (writes.length>1) {
123         throw new Error();
124       }
125     }
126     /* If the chain is empty, we can exit now */
127     if (first==null)
128       return;
129
130     /* The chain is built at this point. */
131     FlatNode[] prevarray=new FlatNode[entrance.numPrev()];
132     for(int i=0;i<entrance.numPrev();i++) {
133       prevarray[i]=entrance.getPrev(i);
134     }
135     for(int i=0;i<prevarray.length;i++) {
136       FlatNode prev=prevarray[i];
137
138       if (!lelements.contains(ntooremap(prev))) {
139         //need to fix this edge
140         for(int j=0;j<prev.numNext();j++) {
141           if (prev.getNext(j)==entrance)
142             prev.setNext(j, first);
143         }
144       }
145     }
146     last.addNext(entrance);
147   }
148
149   public void unrollLoop(Loops l, FlatMethod fm) {
150     assert l.loopEntrances().size()==1;
151     //deal with possibility that entrance has been hoisted
152     FlatNode entrance=(FlatNode)l.loopEntrances().iterator().next();
153     entrance=otonremap(entrance);
154
155     Set lelements=l.loopIncElements();
156
157     Set<FlatNode> tohoist=loopinv.hoisted;
158     Hashtable<FlatNode, TempDescriptor> temptable=new Hashtable<FlatNode, TempDescriptor>();
159     Hashtable<FlatNode, FlatNode> copytable=new Hashtable<FlatNode, FlatNode>();
160     Hashtable<FlatNode, FlatNode> copyendtable=new Hashtable<FlatNode, FlatNode>();
161
162     TempMap t=new TempMap();
163     /* Copy the nodes */
164     for(Iterator it=lelements.iterator();it.hasNext();) {
165       FlatNode fn=(FlatNode)it.next();
166       FlatNode nfn=otonremap(fn);
167
168       FlatNode copy=nfn.clone(t);
169       FlatNode copyend=copy;
170       if (tohoist.contains(fn)) {
171         //deal with the possiblity we already hoisted this node
172         if (clonemap.containsKey(fn)) {
173           FlatNode fnnew=clonemap.get(fn);
174           TempDescriptor writenew[]=fnnew.writesTemps();
175           temptable.put(nfn, writenew[0]);
176         } else {
177           TempDescriptor[] writes=nfn.writesTemps();
178           TempDescriptor tmp=writes[0];
179           TempDescriptor ntmp=tmp.createNew("b");
180           temptable.put(nfn, ntmp);
181           copyend=new FlatOpNode(ntmp, tmp, null, new Operation(Operation.ASSIGN));
182           copy.addNext(copyend);
183         }
184       }
185       copytable.put(nfn, copy);
186       copyendtable.put(nfn, copyend);
187     }
188
189     /* Store initial in set for loop header */
190     FlatNode[] prevarray=new FlatNode[entrance.numPrev()];
191     for(int i=0;i<entrance.numPrev();i++) {
192       prevarray[i]=entrance.getPrev(i);
193     }
194     FlatNode first=copytable.get(entrance);
195
196     /* Copy the internal edges */
197     for(Iterator it=lelements.iterator();it.hasNext();) {
198       FlatNode fn=(FlatNode)it.next();
199       fn=otonremap(fn);
200       FlatNode copyend=copyendtable.get(fn);
201       for(int i=0;i<fn.numNext();i++) {
202         FlatNode nnext=fn.getNext(i);
203         if (nnext==entrance) {
204           /* Back to loop header...point to old graph */
205           copyend.setNewNext(i,nnext);
206         } else if (lelements.contains(ntooremap(nnext))) {
207           /* In graph...point to first graph */
208           copyend.setNewNext(i,copytable.get(nnext));
209         } else {
210           /* Outside loop */
211           /* Just goto same place as before */
212           copyend.setNewNext(i,nnext);
213         }
214       }
215     }
216
217     /* Splice header in using original in set */
218     for(int i=0;i<prevarray.length;i++) {
219       FlatNode prev=prevarray[i];
220
221       if (!lelements.contains(ntooremap(prev))) {
222         //need to fix this edge
223         for(int j=0;j<prev.numNext();j++) {
224           if (prev.getNext(j)==entrance) {
225             prev.setNext(j, first);
226           }
227         }
228       }
229     }
230
231     /* Splice out loop invariant stuff */
232     for(Iterator it=lelements.iterator();it.hasNext();) {
233       FlatNode fn=(FlatNode)it.next();
234       FlatNode nfn=otonremap(fn);
235       if (tohoist.contains(fn)) {
236         TempDescriptor[] writes=nfn.writesTemps();
237         TempDescriptor tmp=writes[0];
238         FlatOpNode fon=new FlatOpNode(tmp, temptable.get(nfn), null, new Operation(Operation.ASSIGN));
239         nfn.replace(fon);
240       }
241     }
242   }
243 }