changes.
[IRC.git] / Robust / src / Analysis / Loops / LoopTerminate.java
1 package Analysis.Loops;
2
3 import java.util.HashSet;
4 import java.util.Hashtable;
5 import java.util.Iterator;
6 import java.util.Set;
7
8 import IR.Operation;
9 import IR.Flat.FKind;
10 import IR.Flat.FlatCondBranch;
11 import IR.Flat.FlatLiteralNode;
12 import IR.Flat.FlatMethod;
13 import IR.Flat.FlatNode;
14 import IR.Flat.FlatOpNode;
15 import IR.Flat.TempDescriptor;
16
17 public class LoopTerminate {
18
19   LoopInvariant loopInv;
20   Set<TempDescriptor> inductionSet;
21
22   public void terminateAnalysis(FlatMethod fm, LoopInvariant loopInv) {
23     this.loopInv = loopInv;
24     this.inductionSet = new HashSet<TempDescriptor>();
25     Loops loopFinder = loopInv.root;
26     if (loopFinder.nestedLoops().size() > 0) {
27       for (Iterator lpit = loopFinder.nestedLoops().iterator(); lpit.hasNext();) {
28         Loops loop = (Loops) lpit.next();
29         Set entrances = loop.loopEntrances();
30         processLoop(fm, loop, loopInv);
31       }
32     }
33   }
34
35   public void processLoop(FlatMethod fm, Loops l, LoopInvariant loopInv) {
36
37     boolean changed = true;
38
39     Set elements = l.loopIncElements();
40     Set toprocess = l.loopIncElements();
41     Set entrances = l.loopEntrances();
42     assert entrances.size() == 1;
43     FlatNode entrance = (FlatNode) entrances.iterator().next();
44
45     Hashtable<TempDescriptor, FlatNode> inductionVar2DefNode =
46         new Hashtable<TempDescriptor, FlatNode>();
47
48     Hashtable<TempDescriptor, TempDescriptor> derivedVar2basicInduction =
49         new Hashtable<TempDescriptor, TempDescriptor>();
50
51     Set<FlatNode> computed = new HashSet<FlatNode>();
52
53     int backEdgeWithInductionCond = 0;
54
55     // #1 find out basic induction variable
56     // variable i is a basic induction variable in loop if the only definitions
57     // of i within L are of the form i=i+c or i=i-c where c is loop invariant
58     for (Iterator elit = elements.iterator(); elit.hasNext();) {
59       FlatNode fn = (FlatNode) elit.next();
60       if (fn.kind() == FKind.FlatOpNode) {
61         FlatOpNode fon = (FlatOpNode) fn;
62         int op = fon.getOp().getOp();
63         if (op == Operation.ADD /* || op == Operation.SUB */) {
64           TempDescriptor tdLeft = fon.getLeft();
65           TempDescriptor tdRight = fon.getRight();
66
67           boolean isLeftLoopInvariant = isLoopInvariantVar(l, fn, tdLeft);
68           boolean isRightLoopInvariant = isLoopInvariantVar(l, fn, tdRight);
69
70           if (isLeftLoopInvariant ^ isRightLoopInvariant) {
71
72             TempDescriptor candidateTemp;
73
74             if (isLeftLoopInvariant) {
75               candidateTemp = tdRight;
76             } else {
77               candidateTemp = tdLeft;
78             }
79
80             Set<FlatNode> defSetOfLoop = getDefinitionInsideLoop(l, fn, candidateTemp);
81             if (defSetOfLoop.size() == 1) {
82               FlatNode defNode = defSetOfLoop.iterator().next();
83               assert defNode.readsTemps().length == 1;
84
85               TempDescriptor readTemp = defNode.readsTemps()[0];
86               if (readTemp.equals(fon.getDest())) {
87                 inductionVar2DefNode.put(candidateTemp, defSetOfLoop.iterator().next());
88                 inductionSet.add(candidateTemp);
89                 computed.add(fn);
90               }
91
92             }
93
94           }
95
96         }
97       }
98     }
99
100     // #2 detect derived induction variables
101     // variable k is a derived induction variable if
102     // 1) there is only one definition of k within the loop, of the form k=j*c
103     // or k=j+d where j is induction variable, c, d are loop-invariant
104     // 2) and if j is a derived induction variable in the family of i, then:
105     // (a) the only definition of j that reaches k is the one in the loop
106     // (b) and there is no definition of i on any path between the definition of
107     // j and the definition of k
108
109     Set<TempDescriptor> basicInductionSet = new HashSet<TempDescriptor>();
110     basicInductionSet.addAll(inductionSet);
111
112     while (changed) {
113       changed = false;
114       for (Iterator elit = elements.iterator(); elit.hasNext();) {
115         FlatNode fn = (FlatNode) elit.next();
116         if (!computed.contains(fn)) {
117           if (fn.kind() == FKind.FlatOpNode) {
118             FlatOpNode fon = (FlatOpNode) fn;
119             int op = fon.getOp().getOp();
120             if (op == Operation.ADD || op == Operation.MULT) {
121               TempDescriptor tdLeft = fon.getLeft();
122               TempDescriptor tdRight = fon.getRight();
123               TempDescriptor tdDest = fon.getDest();
124
125               boolean isLeftLoopInvariant = isLoopInvariantVar(l, fn, tdLeft);
126               boolean isRightLoopInvariant = isLoopInvariantVar(l, fn, tdRight);
127
128               if (isLeftLoopInvariant ^ isRightLoopInvariant) {
129                 TempDescriptor inductionOp;
130                 if (isLeftLoopInvariant) {
131                   inductionOp = tdRight;
132                 } else {
133                   inductionOp = tdLeft;
134                 }
135                 if (inductionSet.contains(inductionOp)) {
136                   // find new derived one k
137
138                   if (!basicInductionSet.contains(inductionOp)) {
139                     // check if only definition of j that reaches k is the one
140                     // in
141                     // the loop
142                     Set defSet = getDefinitionInsideLoop(l, fn, inductionOp);
143                     if (defSet.size() == 1) {
144                       // check if there is no def of i on any path bet' def of j
145                       // and def of k
146
147                       TempDescriptor originInduc = derivedVar2basicInduction.get(inductionOp);
148                       FlatNode defI = inductionVar2DefNode.get(originInduc);
149                       FlatNode defJ = inductionVar2DefNode.get(inductionOp);
150                       FlatNode defk = fn;
151
152                       if (!checkPath(defI, defJ, defk)) {
153                         continue;
154                       }
155
156                     }
157                   }
158                   // add new induction var
159
160                   Set<FlatNode> setUseNode = loopInv.usedef.useMap(fn, tdDest);
161                   assert setUseNode.size() == 1;
162                   assert setUseNode.iterator().next().writesTemps().length == 1;
163
164                   TempDescriptor derivedInd = setUseNode.iterator().next().writesTemps()[0];
165                   FlatNode defNode = setUseNode.iterator().next();
166
167                   computed.add(fn);
168                   computed.add(defNode);
169                   inductionSet.add(derivedInd);
170                   inductionVar2DefNode.put(derivedInd, defNode);
171                   derivedVar2basicInduction.put(derivedInd, inductionOp);
172                   changed = true;
173                 }
174
175               }
176
177             }
178
179           }
180         }
181
182       }
183     }
184
185     // #3 check condition branch
186     for (Iterator elit = elements.iterator(); elit.hasNext();) {
187       FlatNode fn = (FlatNode) elit.next();
188       if (fn.kind() == FKind.FlatCondBranch) {
189         FlatCondBranch fcb = (FlatCondBranch) fn;
190
191         if (fcb.isLoopBranch() || hasLoopExitNode(l, fcb, true)) {
192           // only need to care about conditional branch that leads it out of the
193           // loop
194           Set<FlatNode> condSet = getDefinitionInsideLoop(l, fn, fcb.getTest());
195           assert condSet.size() == 1;
196           FlatNode condFn = condSet.iterator().next();
197           if (condFn instanceof FlatOpNode) {
198             FlatOpNode condOp = (FlatOpNode) condFn;
199             // check if guard condition is composed only with induction
200             // variables
201             if (checkConditionNode(l, condOp)) {
202               backEdgeWithInductionCond++;
203             }
204           }
205         }
206       }
207
208     }
209
210     if (backEdgeWithInductionCond == 0) {
211       throw new Error("Loop may never terminate at "
212           + fm.getMethod().getClassDesc().getSourceFileName() + "::" + entrance.numLine);
213     }
214
215   }
216
217   private boolean checkPath(FlatNode def, FlatNode start, FlatNode end) {
218
219     // return true if there is no def in-bet start and end
220
221     Set<FlatNode> endSet = new HashSet<FlatNode>();
222     endSet.add(end);
223     if ((start.getReachableSet(endSet)).contains(def)) {
224       return false;
225     }
226
227     return true;
228   }
229
230   private boolean checkConditionNode(Loops l, FlatOpNode fon) {
231     // check flatOpNode that computes loop guard condition
232     // currently we assume that induction variable is always getting bigger
233     // and guard variable is constant
234     // so need to check (1) one of operand should be induction variable
235     // (2) another operand should be constant or loop invariant
236
237     TempDescriptor induction = null;
238     TempDescriptor guard = null;
239
240     int op = fon.getOp().getOp();
241     if (op == Operation.LT || op == Operation.LTE) {
242       // condition is inductionVar <= loop invariant
243       induction = fon.getLeft();
244       guard = fon.getRight();
245     } else if (op == Operation.GT || op == Operation.GTE) {
246       // condition is loop invariant >= inductionVar
247       induction = fon.getRight();
248       guard = fon.getLeft();
249     } else {
250       return false;
251     }
252
253     if (!IsInductionVar(l, fon, induction)) {
254       return false;
255     }
256
257     if (guard != null) {
258       Set guardDefSet = getDefinitionInsideLoop(l, fon, guard);
259       for (Iterator iterator = guardDefSet.iterator(); iterator.hasNext();) {
260         FlatNode guardDef = (FlatNode) iterator.next();
261         if (!(guardDef instanceof FlatLiteralNode) && !loopInv.hoisted.contains(guardDef)) {
262           return false;
263         }
264       }
265     }
266
267     return true;
268   }
269
270   private boolean IsInductionVar(Loops l, FlatNode fn, TempDescriptor td) {
271
272     if (inductionSet.contains(td)) {
273       return true;
274     } else {
275       // check if td is composed by induction variables
276       Set<FlatNode> defSet = getDefinitionInsideLoop(l, fn, td);
277       for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
278         FlatNode defNode = (FlatNode) iterator.next();
279
280         TempDescriptor[] readTemps = defNode.readsTemps();
281         for (int i = 0; i < readTemps.length; i++) {
282
283           if (!IsInductionVar(l, defNode, readTemps[i])) {
284             if (!isLoopInvariantVar(l, defNode, readTemps[i])) {
285               return false;
286             }
287           }
288         }
289
290       }
291     }
292     return true;
293   }
294
295   private boolean isLoopInvariantVar(Loops l, FlatNode fn, TempDescriptor td) {
296
297     Set elements = l.loopIncElements();
298     Set<FlatNode> defset = loopInv.usedef.defMap(fn, td);
299
300     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
301     for (Iterator<FlatNode> defit = defset.iterator(); defit.hasNext();) {
302       FlatNode def = defit.next();
303       if (elements.contains(def)) {
304         defSetOfLoop.add(def);
305       }
306     }
307
308     if (defSetOfLoop.size() == 0) {
309       // all definition comes from outside the loop
310       // so it is loop invariant
311       return true;
312     } else if (defSetOfLoop.size() == 1) {
313       // check if def is 1) constant node or 2) loop invariant
314       FlatNode defFlatNode = defSetOfLoop.iterator().next();
315       if (defFlatNode instanceof FlatLiteralNode || loopInv.hoisted.contains(defFlatNode)) {
316         return true;
317       }
318     }
319
320     return false;
321
322   }
323
324   private Set<FlatNode> getDefinitionInsideLoop(Loops l, FlatNode fn, TempDescriptor td) {
325
326     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
327     Set loopElements = l.loopIncElements();
328
329     Set defSet = loopInv.usedef.defMap(fn, td);
330     for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
331       FlatNode defFlatNode = (FlatNode) iterator.next();
332       if (loopElements.contains(defFlatNode)) {
333         defSetOfLoop.add(defFlatNode);
334       }
335     }
336
337     return defSetOfLoop;
338
339   }
340
341   private boolean hasLoopExitNode(Loops l, FlatCondBranch fcb, boolean fromTrueBlock) {
342
343     Set loopElements = l.loopIncElements();
344     Set entrances = l.loopEntrances();
345     FlatNode fn = (FlatNode) entrances.iterator().next();
346
347     if (!fromTrueBlock) {
348       // in this case, FlatCondBranch must have two next flat node, one for true
349       // block and one for false block
350       assert fcb.next.size() == 2;
351     }
352
353     FlatNode next;
354     if (fromTrueBlock) {
355       next = fcb.getNext(0);
356     } else {
357       next = fcb.getNext(1);
358     }
359
360     if (hasLoopExitNode(fn, next, loopElements)) {
361       return true;
362     } else {
363       return false;
364     }
365
366   }
367
368   private boolean hasLoopExitNode(FlatNode loopHeader, FlatNode start, Set loopElements) {
369
370     Set<FlatNode> tovisit = new HashSet<FlatNode>();
371     Set<FlatNode> visited = new HashSet<FlatNode>();
372     tovisit.add(start);
373
374     while (!tovisit.isEmpty()) {
375
376       FlatNode fn = tovisit.iterator().next();
377       tovisit.remove(fn);
378       visited.add(fn);
379
380       if (!loopElements.contains(fn)) {
381         // check if this loop exit is derived from start node
382         return true;
383       }
384
385       for (int i = 0; i < fn.numNext(); i++) {
386         FlatNode next = fn.getNext(i);
387         if (!visited.contains(next)) {
388           if (loopInv.domtree.idom(next).equals(fn)) {
389             // add next node only if current node is immediate dominator of the
390             // next node
391             tovisit.add(next);
392           }
393         }
394       }
395
396     }
397
398     return false;
399
400   }
401 }