forgot to commit
[IRC.git] / Robust / src / Analysis / Loops / LoopTerminate.java
1 package Analysis.Loops;
2
3 import java.util.HashMap;
4 import java.util.HashSet;
5 import java.util.Iterator;
6 import java.util.Set;
7
8 import Analysis.SSJava.SSJavaAnalysis;
9 import IR.FieldDescriptor;
10 import IR.Operation;
11 import IR.State;
12 import IR.Flat.FKind;
13 import IR.Flat.FlatCondBranch;
14 import IR.Flat.FlatFieldNode;
15 import IR.Flat.FlatMethod;
16 import IR.Flat.FlatNode;
17 import IR.Flat.FlatOpNode;
18 import IR.Flat.FlatSetFieldNode;
19 import IR.Flat.TempDescriptor;
20
21 public class LoopTerminate {
22
23   private FlatMethod fm;
24   private LoopInvariant loopInv;
25   private Set<TempDescriptor> inductionSet;
26   // mapping from Induction Variable TempDescriptor to Flat Node that defines
27   // it
28   private HashMap<TempDescriptor, FlatNode> inductionVar2DefNode;
29
30   // mapping from Derived Induction Variable TempDescriptor to its root
31   // induction variable TempDescriptor
32   private HashMap<TempDescriptor, TempDescriptor> derivedVar2basicInduction;
33
34   // maps a loop entrance to the result of termination analysis
35   private HashMap<FlatNode, Boolean> loopEntranceToTermination;
36
37   Set<FlatNode> computed;
38
39   State state;
40   SSJavaAnalysis ssjava;
41
42   /**
43    * Constructor for Loop Termination Analysis
44    */
45   public LoopTerminate(SSJavaAnalysis ssjava, State state) {
46     this.ssjava = ssjava;
47     this.state = state;
48     this.inductionSet = new HashSet<TempDescriptor>();
49     this.inductionVar2DefNode = new HashMap<TempDescriptor, FlatNode>();
50     this.derivedVar2basicInduction = new HashMap<TempDescriptor, TempDescriptor>();
51     this.computed = new HashSet<FlatNode>();
52     this.loopEntranceToTermination = new HashMap<FlatNode, Boolean>();
53   }
54
55   /**
56    * starts loop termination analysis
57    * 
58    * @param fm
59    *          FlatMethod for termination analysis
60    * @param loopInv
61    *          LoopInvariants for given method
62    */
63   public void terminateAnalysis(FlatMethod fm, LoopInvariant loopInv) {
64     this.fm = fm;
65     this.loopInv = loopInv;
66     Loops loopFinder = loopInv.root;
67     recurse(fm, loopFinder);
68   }
69
70   /**
71    * 
72    * spawn analysis for its child and then iterate over the current level of
73    * loops
74    * 
75    * @param fm
76    *          FlatMethod for loop analysis
77    * @param parent
78    *          the current level of loop
79    */
80   private void recurse(FlatMethod fm, Loops parent) {
81     for (Iterator lpit = parent.nestedLoops().iterator(); lpit.hasNext();) {
82       Loops child = (Loops) lpit.next();
83       recurse(fm, child);
84       processLoop(fm, child);
85     }
86   }
87
88   /**
89    * initialize internal data structure
90    */
91   private void init() {
92     inductionSet.clear();
93     inductionVar2DefNode.clear();
94     derivedVar2basicInduction.clear();
95   }
96
97   /**
98    * analysis loop for termination property
99    * 
100    * @param fm
101    *          FlatMethod that contains loop l
102    * @param l
103    *          analysis target loop l
104    */
105   private void processLoop(FlatMethod fm, Loops l) {
106
107     Set loopElements = l.loopIncElements(); // loop body elements
108     Set loopEntrances = l.loopEntrances(); // loop entrance
109     assert loopEntrances.size() == 1;
110     FlatNode loopEntrance = (FlatNode) loopEntrances.iterator().next();
111
112     String loopLabel = (String) state.fn2labelMap.get(loopEntrance);
113
114     if (loopLabel == null || !loopLabel.startsWith(ssjava.TERMINATE)) {
115       init();
116       detectBasicInductionVar(loopElements);
117       detectDerivedInductionVar(loopElements);
118       checkConditionBranch(loopEntrance, loopElements);
119     }
120
121   }
122
123   /**
124    * check if condition branch node satisfies loop condition
125    * 
126    * @param loopEntrance
127    *          loop entrance flat node
128    * @param loopElements
129    *          elements of current loop and all nested loop
130    */
131   private void checkConditionBranch(FlatNode loopEntrance, Set loopElements) {
132     // In the loop, every guard condition of the loop must be composed by
133     // induction & invariants
134     // every guard condition of the if-statement that leads it to the exit must
135     // be composed by induction&invariants
136
137     Set<FlatNode> tovisit = new HashSet<FlatNode>();
138     Set<FlatNode> visited = new HashSet<FlatNode>();
139     tovisit.add(loopEntrance);
140
141     int numMustTerminateGuardCondtion = 0;
142     while (!tovisit.isEmpty()) {
143       FlatNode fnvisit = tovisit.iterator().next();
144       tovisit.remove(fnvisit);
145       visited.add(fnvisit);
146
147       if (fnvisit.kind() == FKind.FlatCondBranch) {
148         FlatCondBranch fcb = (FlatCondBranch) fnvisit;
149
150         if ((fcb.isLoopBranch() && fcb.getLoopEntrance().equals(loopEntrance))
151             || hasLoopExitNode(fcb, true, loopEntrance, loopElements)) {
152           // current FlatCondBranch can introduce loop exits
153           // in this case, guard condition of it should be composed only by loop
154           // invariants and induction variables
155           Set<FlatNode> condSet = getDefinitionInLoop(fnvisit, fcb.getTest(), loopElements);
156           assert condSet.size() == 1;
157           FlatNode condFn = condSet.iterator().next();
158           // assume that guard condition node is always a conditional inequality
159           if (condFn instanceof FlatOpNode) {
160             FlatOpNode condOp = (FlatOpNode) condFn;
161             // check if guard condition is composed only with induction
162             // variables
163             if (checkConditionNode(condOp, fcb.isLoopBranch(), loopElements)) {
164               numMustTerminateGuardCondtion++;
165             } else {
166               if (!fcb.isLoopBranch()) {
167                 // I DON'T THIINK WE NEED TO CARE ABOUT THIS CASE!!!
168                 // if it is if-condition and it leads us to loop exit,
169                 // corresponding guard condition should be composed by induction
170                 // and invariants
171                 // throw new Error("Loop may never terminate at "
172                 // + fm.getMethod().getClassDesc().getSourceFileName() + "::"
173                 // + loopEntrance.numLine);
174               }
175             }
176           }
177         }
178       }
179
180       for (int i = 0; i < fnvisit.numNext(); i++) {
181         FlatNode next = fnvisit.getNext(i);
182         if (loopElements.contains(next) && !visited.contains(next)) {
183           tovisit.add(next);
184         }
185       }
186
187     }
188
189     // # of must-terminate loop condition must be equal to or larger than # of
190     // loop
191     if (numMustTerminateGuardCondtion == 0) {
192       throw new Error("Loop may never terminate at "
193           + fm.getMethod().getClassDesc().getSourceFileName() + "::" + loopEntrance.numLine);
194     }
195   }
196
197   /**
198    * detect derived induction variable
199    * 
200    * @param loopElements
201    *          elements of current loop and all nested loop
202    */
203   private void detectDerivedInductionVar(Set loopElements) {
204     // 2) detect derived induction variables
205     // variable k is a derived induction variable if
206     // 1) there is only one definition of k within the loop, of the form k=j*c
207     // or k=j+d where j is induction variable, c, d are loop-invariant
208     // 2) and if j is a derived induction variable in the family of i, then:
209     // (a) the only definition of j that reaches k is the one in the loop
210     // (b) and there is no definition of i on any path between the definition of
211     // j and the definition of k
212
213     boolean changed = true;
214     Set<TempDescriptor> basicInductionSet = new HashSet<TempDescriptor>();
215     basicInductionSet.addAll(inductionSet);
216
217     while (changed) {
218       changed = false;
219       nextfn: for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
220         FlatNode fn = (FlatNode) elit.next();
221         if (!computed.contains(fn)) {
222           if (fn.kind() == FKind.FlatOpNode) {
223             FlatOpNode fon = (FlatOpNode) fn;
224             int op = fon.getOp().getOp();
225             if (op == Operation.ADD || op == Operation.MULT) {
226               TempDescriptor tdLeft = fon.getLeft();
227               TempDescriptor tdRight = fon.getRight();
228               TempDescriptor tdDest = fon.getDest();
229
230               boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
231               boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
232
233               if (isLeftLoopInvariant ^ isRightLoopInvariant) {
234                 TempDescriptor inductionOp;
235                 if (isLeftLoopInvariant) {
236                   inductionOp = tdRight;
237                 } else {
238                   inductionOp = tdLeft;
239                 }
240
241                 if (inductionSet.contains(inductionOp)) {
242                   // find new derived one k
243
244                   if (!basicInductionSet.contains(inductionOp)) {
245                     // in this case, induction variable 'j' is derived from
246                     // basic induction var
247
248                     // check if only definition of j that reaches k is the one
249                     // in the loop
250
251                     Set<FlatNode> defSet = getDefinitionInLoop(fn, inductionOp, loopElements);
252                     if (defSet.size() == 1) {
253                       // check if there is no def of i on any path bet' def of j
254                       // and def of k
255
256                       TempDescriptor originInduc = derivedVar2basicInduction.get(inductionOp);
257                       FlatNode defI = inductionVar2DefNode.get(originInduc);
258                       FlatNode defJ = inductionVar2DefNode.get(inductionOp);
259                       FlatNode defk = fn;
260
261                       if (!checkPath(defI, defJ, defk)) {
262                         continue nextfn;
263                       }
264
265                     }
266                   }
267                   // add new induction var
268
269                   // when tdDest has the form of srctmp(tdDest) = inductionOp +
270                   // loopInvariant
271                   // want to have the definition of srctmp
272                   Set<FlatNode> setUseNode = loopInv.usedef.useMap(fn, tdDest);
273                   assert setUseNode.size() == 1;
274                   assert setUseNode.iterator().next().writesTemps().length == 1;
275
276                   FlatNode srcDefNode = setUseNode.iterator().next();
277                   if (srcDefNode instanceof FlatOpNode) {
278                     if (((FlatOpNode) srcDefNode).getOp().getOp() == Operation.ASSIGN) {
279                       TempDescriptor derivedIndVar = setUseNode.iterator().next().writesTemps()[0];
280                       FlatNode defNode = setUseNode.iterator().next();
281
282                       computed.add(fn);
283                       computed.add(defNode);
284                       inductionSet.add(derivedIndVar);
285                       inductionVar2DefNode.put(derivedIndVar, defNode);
286                       derivedVar2basicInduction.put(derivedIndVar, inductionOp);
287                       changed = true;
288                     }
289                   }
290
291                 }
292
293               }
294
295             }
296
297           }
298         }
299
300       }
301     }
302
303   }
304
305   /**
306    * detect basic induction variable
307    * 
308    * @param loopElements
309    *          elements of current loop and all nested loop
310    */
311   private void detectBasicInductionVar(Set loopElements) {
312     // 1) find out basic induction variable
313     // variable i is a basic induction variable in loop if the only definitions
314     // of i within L are of the form i=i+c where c is loop invariant
315
316     for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
317       FlatNode fn = (FlatNode) elit.next();
318       if (fn.kind() == FKind.FlatOpNode) {
319         FlatOpNode fon = (FlatOpNode) fn;
320         int op = fon.getOp().getOp();
321         if (op == Operation.ADD) {
322           TempDescriptor tdLeft = fon.getLeft();
323           TempDescriptor tdRight = fon.getRight();
324
325           boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
326           boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
327
328           if (isLeftLoopInvariant ^ isRightLoopInvariant) {
329
330             TempDescriptor candidateTemp;
331
332             if (isLeftLoopInvariant) {
333               candidateTemp = tdRight;
334             } else {
335               candidateTemp = tdLeft;
336             }
337
338             Set<FlatNode> defSetOfLoop = getDefinitionInLoop(fn, candidateTemp, loopElements);
339             // basic induction variable must have only one definition within the
340             // loop
341             if (defSetOfLoop.size() == 1) {
342               // find out definition of induction var, form of Flat
343               // Assign:inductionVar = candidateTemp
344               FlatNode indNode = defSetOfLoop.iterator().next();
345               assert indNode.readsTemps().length == 1;
346               TempDescriptor readTemp = indNode.readsTemps()[0];
347               if (readTemp.equals(fon.getDest())) {
348                 inductionVar2DefNode.put(candidateTemp, defSetOfLoop.iterator().next());
349                 inductionVar2DefNode.put(readTemp, defSetOfLoop.iterator().next());
350                 inductionSet.add(fon.getDest());
351                 inductionSet.add(candidateTemp);
352                 computed.add(fn);
353               }
354
355             }
356
357           }
358
359         }
360       }
361     }
362
363   }
364
365   /**
366    * check whether there is no definition node 'def' on any path between 'start'
367    * node and 'end' node
368    * 
369    * @param def
370    * @param start
371    * @param end
372    * @return true if there is no def in-bet start and end
373    */
374   private boolean checkPath(FlatNode def, FlatNode start, FlatNode end) {
375     Set<FlatNode> endSet = new HashSet<FlatNode>();
376     endSet.add(end);
377     return !(start.getReachableSet(endSet)).contains(def);
378   }
379
380   /**
381    * check condition node satisfies termination condition
382    * 
383    * @param fon
384    *          condition node FlatOpNode
385    * @param isLoopCondition
386    *          true if condition is loop condition
387    * @param loopElements
388    *          elements of current loop and all nested loop
389    * @return true if it satisfies termination condition
390    */
391   private boolean checkConditionNode(FlatOpNode fon, boolean isLoopCondition, Set loopElements) {
392     // check flatOpNode that computes loop guard condition
393     // currently we assume that induction variable is always getting bigger
394     // and guard variable is constant
395     // so need to check (1) one of operand should be induction variable
396     // (2) another operand should be constant or loop invariant
397
398     TempDescriptor induction = null;
399     TempDescriptor guard = null;
400
401     int op = fon.getOp().getOp();
402     if (op == Operation.LT || op == Operation.LTE) {
403       if (isLoopCondition) {
404         // loop condition is inductionVar <= loop invariant
405         induction = fon.getLeft();
406         guard = fon.getRight();
407       } else {
408         // if-statement condition is loop invariant <= inductionVar since
409         // inductionVar is getting biggier each iteration
410         induction = fon.getRight();
411         guard = fon.getLeft();
412       }
413     } else if (op == Operation.GT || op == Operation.GTE) {
414       if (isLoopCondition) {
415         // condition is loop invariant >= inductionVar
416         induction = fon.getRight();
417         guard = fon.getLeft();
418       } else {
419         // if-statement condition is loop inductionVar >= invariant
420         induction = fon.getLeft();
421         guard = fon.getRight();
422       }
423     } else {
424       return false;
425     }
426
427     // here, verify that guard operand is an induction variable
428     if (!hasInductionVar(fon, induction, loopElements, new HashSet<TempDescriptor>())) {
429       return false;
430     }
431
432     if (guard != null) {
433       Set guardDefSet = getDefinitionInLoop(fon, guard, loopElements);
434       for (Iterator iterator = guardDefSet.iterator(); iterator.hasNext();) {
435         FlatNode guardDef = (FlatNode) iterator.next();
436         if (guardDef.kind() == FKind.FlatFieldNode) {
437           FlatFieldNode ffn = (FlatFieldNode) guardDef;
438           if ((ffn.getField().isStatic() && ffn.getField().isFinal())
439               || (!hasFieldAccessInLoopElements(ffn, loopElements))) {
440             // if field is STATIC FINAL field or field is not appeared inside
441             // the current loop, allow it to be the guard
442             // condition
443             return true;
444           } else {
445             return false;
446           }
447         } else if (!(guardDef.kind() == FKind.FlatLiteralNode)
448             && !loopInv.hoisted.contains(guardDef)) {
449           return false;
450         }
451       }
452     }
453     return true;
454   }
455
456   private boolean hasFieldAccessInLoopElements(FlatFieldNode guardNode, Set loopElements) {
457     for (Iterator iterator = loopElements.iterator(); iterator.hasNext();) {
458       FlatNode fn = (FlatNode) iterator.next();
459       if (fn.kind() == FKind.FlatSetFieldNode) {
460         FlatSetFieldNode ffn = (FlatSetFieldNode) fn;
461         if (!ffn.equals(guardNode) && ffn.getField().equals(guardNode.getField())) {
462           return true;
463         }
464       }
465     }
466     return false;
467   }
468
469   /**
470    * check if TempDescriptor td has at least one induction variable and is
471    * composed only by induction vars +loop invariants
472    * 
473    * @param fn
474    *          FlatNode that contains TempDescriptor 'td'
475    * @param td
476    *          TempDescriptor representing target variable
477    * @param loopElements
478    *          elements of current loop and all nested loop
479    * @return true if 'td' is induction variable
480    */
481   private boolean hasInductionVar(FlatNode fn, TempDescriptor td, Set loopElements,
482       Set<TempDescriptor> visited) {
483
484     visited.add(td);
485     if (inductionSet.contains(td)) {
486       return true;
487     } else {
488       // check if td is composed by induction variables or loop invariants
489       Set<FlatNode> defSet = getDefinitionInLoop(fn, td, loopElements);
490       for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
491         FlatNode defNode = (FlatNode) iterator.next();
492
493         int inductionVarCount = 0;
494         TempDescriptor[] readTemps = defNode.readsTemps();
495         for (int i = 0; i < readTemps.length; i++) {
496           if (!visited.contains(readTemps[i])) {
497             if (!hasInductionVar(defNode, readTemps[i], loopElements, visited)) {
498               if (!isLoopInvariantVar(defNode, readTemps[i], loopElements)) {
499                 return false;
500               }
501             } else {
502               inductionVarCount++;
503             }
504           }
505         }
506         // check definition of td has at least one induction var
507         if (inductionVarCount > 0) {
508           return true;
509         }
510
511       }
512
513       return false;
514     }
515
516   }
517
518   /**
519    * check if TempDescriptor td is loop invariant variable or constant value wrt
520    * the current loop
521    * 
522    * @param fn
523    *          FlatNode that contains TempDescriptor 'td'
524    * @param td
525    *          TempDescriptor representing target variable
526    * @param loopElements
527    *          elements of current loop and all nested loop
528    * @return true if 'td' is loop invariant variable
529    */
530   private boolean isLoopInvariantVar(FlatNode fn, TempDescriptor td, Set loopElements) {
531
532     Set<FlatNode> defset = loopInv.usedef.defMap(fn, td);
533
534     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
535     for (Iterator<FlatNode> defit = defset.iterator(); defit.hasNext();) {
536       FlatNode def = defit.next();
537       if (loopElements.contains(def)) {
538         defSetOfLoop.add(def);
539       }
540     }
541
542     if (defSetOfLoop.size() == 0) {
543       // all definition comes from outside the loop
544       // so it is loop invariant
545       return true;
546     } else if (defSetOfLoop.size() == 1) {
547       // check if def is 1) constant node or 2) loop invariant
548       FlatNode defFlatNode = defSetOfLoop.iterator().next();
549       if (defFlatNode.kind() == FKind.FlatLiteralNode || loopInv.hoisted.contains(defFlatNode)) {
550         return true;
551       }
552     }
553
554     return false;
555
556   }
557
558   /**
559    * compute the set of definitions of variable 'td' inside of the loop
560    * 
561    * @param fn
562    *          FlatNode that uses 'td'
563    * @param td
564    *          target node that we want to have the set of definitions
565    * @param loopElements
566    *          elements of current loop and all nested loop
567    * @return the set of definition nodes for 'td' in the current loop
568    */
569   private Set<FlatNode> getDefinitionInLoop(FlatNode fn, TempDescriptor td, Set loopElements) {
570
571     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
572
573     Set defSet = loopInv.usedef.defMap(fn, td);
574     for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
575       FlatNode defFlatNode = (FlatNode) iterator.next();
576       if (loopElements.contains(defFlatNode)) {
577         defSetOfLoop.add(defFlatNode);
578       }
579     }
580
581     return defSetOfLoop;
582
583   }
584
585   /**
586    * check whether FlatCondBranch introduces loop exit
587    * 
588    * @param fcb
589    *          target node
590    * @param fromTrueBlock
591    *          specify which block is possible to have loop exit
592    * @param loopHeader
593    *          loop header of current loop
594    * @param loopElements
595    *          elements of current loop and all nested loop
596    * @return true if input 'fcb' intrroduces loop exit
597    */
598   private boolean hasLoopExitNode(FlatCondBranch fcb, boolean fromTrueBlock, FlatNode loopHeader,
599       Set loopElements) {
600     // return true if fcb possibly introduces loop exit
601
602     FlatNode next;
603     if (fromTrueBlock) {
604       next = fcb.getNext(0);
605     } else {
606       next = fcb.getNext(1);
607     }
608
609     return hasLoopExitNode(loopHeader, next, loopElements);
610
611   }
612
613   /**
614    * check whether start node reaches loop exit
615    * 
616    * @param loopHeader
617    * @param start
618    * @param loopElements
619    * @return true if a path exist from start to loop exit
620    */
621   private boolean hasLoopExitNode(FlatNode loopHeader, FlatNode start, Set loopElements) {
622
623     Set<FlatNode> tovisit = new HashSet<FlatNode>();
624     Set<FlatNode> visited = new HashSet<FlatNode>();
625     tovisit.add(start);
626
627     while (!tovisit.isEmpty()) {
628
629       FlatNode fn = tovisit.iterator().next();
630       tovisit.remove(fn);
631       visited.add(fn);
632
633       for (int i = 0; i < fn.numNext(); i++) {
634         FlatNode next = fn.getNext(i);
635         if (!visited.contains(next)) {
636           // check that if-body statment introduces loop exit.
637           if (!loopElements.contains(next)) {
638             return true;
639           }
640
641           if (loopInv.domtree.idom(next).equals(fn)) {
642             // add next node only if current node is immediate dominator of the
643             // next node
644             tovisit.add(next);
645           }
646         }
647       }
648
649     }
650
651     return false;
652
653   }
654 }