a whole bunch of optimizations...should be useful for transactions
authorbdemsky <bdemsky>
Thu, 2 Apr 2009 06:32:55 +0000 (06:32 +0000)
committerbdemsky <bdemsky>
Thu, 2 Apr 2009 06:32:55 +0000 (06:32 +0000)
Robust/src/Analysis/Loops/CSE.java [new file with mode: 0644]
Robust/src/Analysis/Loops/CopyPropagation.java [new file with mode: 0644]
Robust/src/Analysis/Loops/DeadCode.java [new file with mode: 0644]
Robust/src/Analysis/Loops/GlobalFieldType.java [new file with mode: 0644]
Robust/src/Analysis/Loops/LoopInvariant.java
Robust/src/Analysis/Loops/LoopOptimize.java
Robust/src/Analysis/Loops/UseDef.java
Robust/src/Analysis/Loops/localCSE.java [new file with mode: 0644]

diff --git a/Robust/src/Analysis/Loops/CSE.java b/Robust/src/Analysis/Loops/CSE.java
new file mode 100644 (file)
index 0000000..0454068
--- /dev/null
@@ -0,0 +1,195 @@
+package Analysis.Loops;
+
+import IR.Flat.*;
+import IR.TypeUtil;
+import IR.Operation;
+import IR.FieldDescriptor;
+import IR.MethodDescriptor;
+import IR.TypeDescriptor;
+import java.util.Map;
+import java.util.Iterator;
+import java.util.Hashtable;
+import java.util.HashSet;
+import java.util.Set;
+
+public class CSE {
+  GlobalFieldType gft;
+  TypeUtil typeutil;
+  public CSE(GlobalFieldType gft, TypeUtil typeutil) {
+    this.gft=gft;
+    this.typeutil=typeutil;
+  }
+
+  public void doAnalysis(FlatMethod fm) {
+    Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr=new Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>>();
+
+    HashSet toprocess=new HashSet();
+    HashSet discovered=new HashSet();
+    toprocess.add(fm);
+    discovered.add(fm);
+    while(!toprocess.isEmpty()) {
+      FlatNode fn=(FlatNode)toprocess.iterator().next();
+      toprocess.remove(fn);
+      for(int i=0;i<fn.numNext();i++) {
+       FlatNode nnext=fn.getNext(i);
+       if (!discovered.contains(nnext)) {
+         toprocess.add(nnext);
+         discovered.add(nnext);
+       }
+      }
+      Hashtable<Expression, TempDescriptor> tab=new Hashtable<Expression, TempDescriptor>();
+      boolean first=true;
+      
+      //compute intersection
+      for(int i=0;i<fn.numPrev();i++) {
+       FlatNode prev=fn.getPrev(i);
+       if (first) {
+         if (availexpr.containsKey(prev))
+           tab.putAll(availexpr.get(prev));
+         first=false;
+       } else {
+         if (availexpr.containsKey(prev)) {
+           Hashtable<Expression, TempDescriptor> table=availexpr.get(prev);
+           for(Iterator mapit=tab.entrySet().iterator();mapit.hasNext();) {
+             Object entry=mapit.next();
+             if (!table.contains(entry))
+               mapit.remove();
+           }
+         }
+       }
+      }
+      //Do kills of expression/variable mappings
+      TempDescriptor[] write=fn.writesTemps();
+      for(int i=0;i<write.length;i++) {
+       if (tab.containsKey(write[i]))
+         tab.remove(write[i]);
+      }
+      
+      switch(fn.kind()) {
+      case FKind.FlatCall:
+       {
+         FlatCall fc=(FlatCall) fn;
+         MethodDescriptor md=fc.getMethod();
+         Set<FieldDescriptor> fields=gft.getFields(md);
+         Set<TypeDescriptor> arrays=gft.getArrays(md);
+         killexpressions(tab, fields, arrays);
+         break;
+       }
+      case FKind.FlatOpNode:
+       {
+         FlatOpNode fon=(FlatOpNode) fn;
+         Expression e=new Expression(fon.getLeft(), fon.getRight(),fon.getOp());
+         tab.put(e, fon.getDest());
+         break;
+       }
+      case FKind.FlatSetFieldNode:
+       {
+         FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
+         Set<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
+         fields.add(fsfn.getField());
+         killexpressions(tab, fields, null);
+         Expression e=new Expression(fsfn.getDst(), fsfn.getField());
+         tab.put(e, fsfn.getSrc());
+         break;
+       }
+      case FKind.FlatFieldNode:
+       {
+         FlatFieldNode ffn=(FlatFieldNode)fn;
+         Expression e=new Expression(ffn.getSrc(), ffn.getField());
+         tab.put(e, ffn.getDst());
+         break;
+       }
+      case FKind.FlatSetElementNode:
+       {
+         FlatSetElementNode fsen=(FlatSetElementNode)fn;
+         Expression e=new Expression(fsen.getDst(),fsen.getIndex());
+         tab.put(e, fsen.getSrc());
+         break;
+       }
+      case FKind.FlatElementNode:
+       {
+         FlatElementNode fen=(FlatElementNode)fn;
+         Expression e=new Expression(fen.getSrc(),fen.getIndex());
+         tab.put(e, fen.getDst());
+         break;
+       }
+      default:
+      }
+      
+      if (write.length==1) {
+       TempDescriptor w=write[0];
+       for(Iterator it=tab.entrySet().iterator();it.hasNext();) {
+         Map.Entry m=(Map.Entry)it.next();
+         Expression e=(Expression)m.getKey();
+         if (e.a==w||e.b==w)
+           it.remove();
+       }
+      }
+      if (!availexpr.containsKey(fn)||!availexpr.get(fn).equals(tab)) {
+       availexpr.put(fn, tab);
+       for(int i=0;i<fn.numNext();i++) {
+         FlatNode nnext=fn.getNext(i);
+         toprocess.add(nnext);
+       }
+      }
+    }
+  }
+
+  public void killexpressions(Hashtable<Expression, TempDescriptor> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays) {
+    for(Iterator it=tab.entrySet().iterator();it.hasNext();) {
+      Map.Entry m=(Map.Entry)it.next();
+      Expression e=(Expression)m.getKey();
+      if (e.f!=null&&fields!=null&&fields.contains(e.f)) 
+       it.remove();
+      else if ((e.a!=null)&&(arrays!=null)) {
+       for(Iterator<TypeDescriptor> arit=arrays.iterator();arit.hasNext();) {
+         TypeDescriptor artd=arit.next();
+         if (typeutil.isSuperorType(artd,e.a.getType())||
+             typeutil.isSuperorType(e.a.getType(),artd)) {
+           it.remove();
+           break;
+         }
+       }
+      }
+    }
+  }
+}
+
+class Expression {
+  Operation op;
+  TempDescriptor a;
+  TempDescriptor b;
+  FieldDescriptor f;
+  Expression(TempDescriptor a, TempDescriptor b, Operation op) {
+    this.a=a;
+    this.b=b;
+    this.op=op;
+  }
+  Expression(TempDescriptor a, FieldDescriptor f) {
+    this.a=a;
+    this.f=f;
+  }
+  Expression(TempDescriptor a, TempDescriptor index) {
+    this.a=a;
+    this.b=index;
+  }
+  public int hashCode() {
+    int h=0;
+    h^=a.hashCode();
+    if (op!=null)
+      h^=op.getOp();
+    if (b!=null)
+      h^=b.hashCode();
+    if (f!=null)
+      h^=f.hashCode();
+    return h;
+  }
+  public boolean equals(Object o) {
+    Expression e=(Expression)o;
+    if (a!=e.a||f!=e.f||b!=e.b)
+      return false;
+    if (op!=null)
+      return op.getOp()==e.op.getOp();
+    return true;
+  }
+}
diff --git a/Robust/src/Analysis/Loops/CopyPropagation.java b/Robust/src/Analysis/Loops/CopyPropagation.java
new file mode 100644 (file)
index 0000000..85ca787
--- /dev/null
@@ -0,0 +1,91 @@
+package Analysis.Loops;
+import IR.Flat.*;
+import IR.Operation;
+import java.util.Iterator;
+import java.util.Hashtable;
+import java.util.HashSet;
+import java.util.Map;
+
+public class CopyPropagation {
+  public CopyPropagation() {
+  }
+
+  public void optimize(FlatMethod fm) {
+    Hashtable<FlatNode, Hashtable<TempDescriptor, TempDescriptor>> table
+      =new Hashtable<FlatNode, Hashtable<TempDescriptor, TempDescriptor>>();
+    boolean changed=false;
+    do {
+      changed=false;
+      HashSet tovisit=new HashSet();
+      HashSet discovered=new HashSet();
+      tovisit.add(fm);
+      discovered.add(fm);
+      while(!tovisit.isEmpty()) {
+       FlatNode fn=(FlatNode) tovisit.iterator().next();
+       tovisit.remove(fn);
+       for(int i=0;i<fn.numNext();i++) {
+         FlatNode nnext=fn.getNext(i);
+         if (!discovered.contains(nnext)) {
+           discovered.add(nnext);
+           tovisit.add(nnext);
+         }
+       }
+       Hashtable<TempDescriptor, TempDescriptor> tab;
+       if (fn.numPrev()>=1&&table.containsKey(fn.getPrev(0)))
+         tab=new Hashtable<TempDescriptor, TempDescriptor>(table.get(fn.getPrev(0)));
+       else
+         tab=new Hashtable<TempDescriptor, TempDescriptor>();
+       //Compute intersection
+       for(int i=1;i<fn.numPrev();i++) {
+         Hashtable<TempDescriptor, TempDescriptor> tp=table.get(fn.getPrev(i));
+         for(Iterator tmpit=tab.entrySet().iterator();tmpit.hasNext();) {
+           Map.Entry t=(Map.Entry)tmpit.next();
+           TempDescriptor tmp=(TempDescriptor)t.getKey();
+           if (tp!=null&&(!tp.containsKey(tmp)||tp.get(tmp)!=tab.get(tmp))) {
+             tmpit.remove();
+           }
+         }
+       }
+       TempDescriptor[]writes=fn.writesTemps();
+       for(int i=0;i<writes.length;i++) {
+         TempDescriptor tmp=writes[i];
+         for(Iterator<TempDescriptor> tmpit=tab.keySet().iterator();tmpit.hasNext();) {        
+           TempDescriptor tmp2=tmpit.next();
+           if (tmp==tab.get(tmp2))
+             tmpit.remove();
+         }
+       }
+       if (fn.kind()==FKind.FlatOpNode) {
+         FlatOpNode fon=(FlatOpNode)fn;
+         if (fon.getOp().getOp()==Operation.ASSIGN) {
+           tab.put(fon.getDest(), fon.getLeft());
+         }
+       }
+       if (!table.containsKey(fn)||!table.get(fn).equals(tab)) {
+         table.put(fn,tab);
+         changed=true;
+         for(int i=0;i<fn.numNext();i++) {
+           FlatNode nnext=fn.getNext(i);
+           tovisit.add(nnext);
+         }
+       }
+      }
+      for(Iterator<FlatNode> it=fm.getNodeSet().iterator();it.hasNext();) {
+       FlatNode fn=it.next();
+       Hashtable<TempDescriptor, TempDescriptor> tab=table.get(fn);
+       TempMap tmap=null;
+       TempDescriptor[]reads=fn.readsTemps();
+       for(int i=0;i<reads.length;i++) {
+         TempDescriptor tmp=reads[i];
+         if (tab.containsKey(tmp)) {
+           if (tmap==null)
+             tmap=new TempMap();
+           tmap.addPair(tmp, tab.get(tmp));
+         }
+       }
+       if (tmap!=null)
+         fn.rewriteUse(tmap);
+      }
+    } while(changed);
+  }
+}
\ No newline at end of file
diff --git a/Robust/src/Analysis/Loops/DeadCode.java b/Robust/src/Analysis/Loops/DeadCode.java
new file mode 100644 (file)
index 0000000..ebbba25
--- /dev/null
@@ -0,0 +1,91 @@
+package Analysis.Loops;
+import IR.Flat.*;
+import IR.Operation;
+import java.util.HashSet;
+import java.util.Iterator;
+
+public class DeadCode {
+  public DeadCode() {
+  }
+  public void optimize(FlatMethod fm) {
+    UseDef ud=new UseDef(fm);
+    HashSet useful=new HashSet();
+    boolean changed=true;
+    while(changed) {
+      changed=false;
+      nextfn:
+      for(Iterator<FlatNode> it=fm.getNodeSet().iterator();it.hasNext();) {
+       FlatNode fn=it.next();
+       switch(fn.kind()) {
+       case FKind.FlatCall:
+       case FKind.FlatFieldNode:
+       case FKind.FlatSetFieldNode:
+       case FKind.FlatNew:
+       case FKind.FlatCastNode:
+       case FKind.FlatReturnNode:
+       case FKind.FlatCondBranch:
+       case FKind.FlatSetElementNode:
+       case FKind.FlatElementNode:
+       case FKind.FlatFlagActionNode:
+       case FKind.FlatCheckNode:
+       case FKind.FlatBackEdge:
+       case FKind.FlatTagDeclaration:
+       case FKind.FlatMethod:
+       case FKind.FlatAtomicEnterNode:
+       case FKind.FlatAtomicExitNode:
+       case FKind.FlatPrefetchNode:
+       case FKind.FlatSESEEnterNode:
+       case FKind.FlatSESEExitNode:
+         if (!useful.contains(fn)) {
+           useful.add(fn);
+           changed=true;
+         }       
+         break;
+       case FKind.FlatOpNode:
+         FlatOpNode fon=(FlatOpNode)fn;
+         if (fon.getOp().getOp()==Operation.DIV||
+             fon.getOp().getOp()==Operation.MOD) {
+           if (!useful.contains(fn)) {
+             useful.add(fn);
+             changed=true;
+           }
+           break;
+         }
+       default:
+         TempDescriptor[] writes=fn.writesTemps();
+         if (!useful.contains(fn))
+           for(int i=0;i<writes.length;i++) {
+             for(Iterator<FlatNode> uit=ud.useMap(fn,writes[i]).iterator();uit.hasNext();) {
+               FlatNode ufn=uit.next();
+               if (useful.contains(ufn)) {
+                 //we are useful
+                 useful.add(fn);
+                 changed=true;
+                 continue nextfn;
+               }
+             }
+           }
+       }
+      }
+    }
+    //get rid of useless nodes
+    for(Iterator<FlatNode> it=fm.getNodeSet().iterator();it.hasNext();) {
+      FlatNode fn=it.next();
+      if (!useful.contains(fn)) {
+       //We have a useless node
+       FlatNode fnnext=fn.getNext(0);
+       for(int i=0;i<fn.numPrev();i++) {
+         FlatNode nprev=fn.getPrev(i);
+         for(int j=0;j<nprev.numNext();j++) {
+           if (nprev.getNext(j)==fn) {
+             nprev.setnext(j, fnnext);
+             fnnext.addPrev(nprev);
+           }
+         }
+       }
+       //fix up prev edge of fnnext
+       fnnext.removePrev(fn);
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/Robust/src/Analysis/Loops/GlobalFieldType.java b/Robust/src/Analysis/Loops/GlobalFieldType.java
new file mode 100644 (file)
index 0000000..cbef454
--- /dev/null
@@ -0,0 +1,88 @@
+package Analysis.Loops;
+
+import IR.Flat.*;
+import IR.State;
+import IR.MethodDescriptor;
+import IR.FieldDescriptor;
+import IR.TypeDescriptor;
+import Analysis.CallGraph.*;
+import java.util.Iterator;
+import java.util.HashSet;
+import java.util.Hashtable;
+import java.util.Set;
+
+public class GlobalFieldType {
+  CallGraph cg;
+  State st;
+  MethodDescriptor root;
+  Hashtable<MethodDescriptor, Set<FieldDescriptor>> fields;
+  Hashtable<MethodDescriptor, Set<TypeDescriptor>> arrays;
+  
+  public GlobalFieldType(CallGraph cg, State st, MethodDescriptor root) {
+    this.cg=cg;
+    this.st=st;
+    this.root=root;
+    this.fields=new Hashtable<MethodDescriptor, Set<FieldDescriptor>>();
+    this.arrays=new Hashtable<MethodDescriptor, Set<TypeDescriptor>>();
+    doAnalysis();
+  }
+  private void doAnalysis() {
+    HashSet toprocess=new HashSet();
+    toprocess.add(root);
+    HashSet discovered=new HashSet();
+    discovered.add(root);
+    while(!toprocess.isEmpty()) {
+      MethodDescriptor md=(MethodDescriptor)toprocess.iterator().next();
+      toprocess.remove(md);
+      analyzeMethod(md);
+      Set callees=cg.getCalleeSet(md);
+      for(Iterator it=callees.iterator();it.hasNext();) {
+       MethodDescriptor md2=(MethodDescriptor)it.next();
+       if (!discovered.contains(md2)) {
+         discovered.add(md2);
+         toprocess.add(md2);
+       }
+      }
+    }
+    boolean changed=true;
+    while(changed) {
+      changed=false;
+      for(Iterator it=discovered.iterator();it.hasNext();) {
+       MethodDescriptor md=(MethodDescriptor)it.next();
+       Set callees=cg.getCalleeSet(md);
+       for(Iterator cit=callees.iterator();cit.hasNext();) {
+         MethodDescriptor md2=(MethodDescriptor)cit.next();
+         if (fields.get(md).addAll(fields.get(md2)))
+           changed=true;
+         if (arrays.get(md).addAll(arrays.get(md2)))
+           changed=true;
+       }
+      }
+    }
+  }
+
+  public Set<FieldDescriptor> getFields(MethodDescriptor md) {
+    return fields.get(md);
+  }
+
+  public Set<TypeDescriptor> getArrays(MethodDescriptor md) {
+    return arrays.get(md);
+  }
+
+  public void analyzeMethod(MethodDescriptor md) {
+    fields.put(md, new HashSet<FieldDescriptor>());
+    arrays.put(md, new HashSet<TypeDescriptor>());
+    
+    FlatMethod fm=st.getMethodFlat(md);
+    for(Iterator it=fm.getNodeSet().iterator();it.hasNext();) {
+      FlatNode fn=(FlatNode)it.next();
+      if (fn.kind()==FKind.FlatSetElementNode) {
+       FlatSetElementNode fsen=(FlatSetElementNode)fn;
+       arrays.get(md).add(fsen.getDst().getType());
+      } else if (fn.kind()==FKind.FlatSetFieldNode) {
+       FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
+       fields.get(md).add(fsfn.getField());
+      }
+    }
+  }
+}
\ No newline at end of file
index db2e2d9355f843dc34bb24c459ca3dbecb4410fd..0e007e1fa3950a7a02c1384acc0509cd94496211 100644 (file)
@@ -1,6 +1,15 @@
 package Analysis.Loops;
 
 import IR.Flat.*;
+import IR.FieldDescriptor;
+import IR.TypeDescriptor;
+import IR.TypeUtil;
+import IR.Operation;
+import java.util.Iterator;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.Vector;
+import java.util.Hashtable;
 
 public class LoopInvariant {
   public LoopInvariant(TypeUtil typeutil) {
@@ -8,18 +17,20 @@ public class LoopInvariant {
   }
   LoopFinder loops;
   DomTree posttree;
-  Hashtable<Loops, Set<FlatNode>> table;
+  Hashtable<Loops, Vector<FlatNode>> table;
   Set<FlatNode> hoisted;
   UseDef usedef;
   TypeUtil typeutil;
+  Set tounroll;
 
   public void analyze(FlatMethod fm) {
     loops=new LoopFinder(fm);
-    Loops root=loops.getRootLoop(fm);
-    table=new Hashtable<Loops, Set<FlatNode>>();
+    Loops root=loops.getRootloop(fm);
+    table=new Hashtable<Loops, Vector<FlatNode>>();
     hoisted=new HashSet<FlatNode>();
     usedef=new UseDef(fm);
     posttree=new DomTree(fm,true);
+    tounroll=new HashSet();
     recurse(root);
   }
 
@@ -40,18 +51,18 @@ public class LoopInvariant {
 
     HashSet<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
     HashSet<TypeDescriptor> types=new HashSet<TypeDescriptor>();
-
+    
     if (!isLeaf) {
       unsafe=true; 
     } else {
       /* Check whether it is safe to reuse values. */
       for(Iterator elit=elements.iterator();elit.hasNext();) {
-       FlatNode fn=elit.next();
+       FlatNode fn=(FlatNode)elit.next();
        if (fn.kind()==FKind.FlatAtomicEnterNode||
            fn.kind()==FKind.FlatAtomicExitNode||
-           fn.kind()==FKind.Call) {
+           fn.kind()==FKind.FlatCall) {
          unsafe=true;
-qq       break;
+         break;
        } else if (fn.kind()==FKind.FlatSetFieldNode) {
          FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
          fields.add(fsfn.getField());
@@ -62,10 +73,10 @@ qq    break;
       }   
     }
     
-    HashSet dominatorset=computeAlways(l);
+    HashSet dominatorset=unsafe?null:computeAlways(l);
 
     /* Compute loop invariants */
-    table.put(l, new HashSet<FlatNode>());
+    table.put(l, new Vector<FlatNode>());
     while(changed) {
       changed=false;
       nextfn:
@@ -73,8 +84,8 @@ qq      break;
        FlatNode fn=(FlatNode)tpit.next();
        switch(fn.kind()) {
        case FKind.FlatOpNode:
-         int op=((FlatOpNode)fn).getOperation();
-         if (op==Operation.DIV||op==Operation.MID||
+         int op=((FlatOpNode)fn).getOp().getOp();
+         if (op==Operation.DIV||op==Operation.MOD||
              checkNode(fn,elements)) {
            continue nextfn;
          }
@@ -91,13 +102,15 @@ qq   break;
              checkNode(fn,elements))
            continue nextfn;
          TypeDescriptor td=((FlatElementNode)fn).getSrc().getType();
-         for(Iterator<TypeDescriptor> tpit=types.iterator();tpit.hasNext();) {
-           TypeDescriptor td2=tpit.next();
+         for(Iterator<TypeDescriptor> tdit=types.iterator();tdit.hasNext();) {
+           TypeDescriptor td2=tdit.next();
            if (typeutil.isSuperorType(td,td2)||
                typeutil.isSuperorType(td2,td)) {
              continue nextfn;
            }
          }
+         if (isLeaf)
+           tounroll.add(l);
          break;
 
        case FKind.FlatFieldNode:
@@ -106,6 +119,8 @@ qq    break;
              checkNode(fn,elements)) {
            continue nextfn;
          }
+         if (isLeaf)
+           tounroll.add(l);
          break;
 
        default:
@@ -121,22 +136,21 @@ qq          break;
   public HashSet computeAlways(Loops l) {
     /* Compute nodes that are always executed in loop */
     HashSet dominatorset=null;
-    if (!unsafe) {
-      /* Compute nodes that definitely get executed in a loop */
-      Set entrances=l.loopEntraces();
-      assert entrances.size()==1;
-      FlatNode entrance=(FlatNode)entrances.iterator().next();
-      boolean first=true;
-      for (int i=0;i<entrances.numPrev();i++) {
-       FlatNode incoming=entrence.getPrev(i);
-       if (elements.contains(incoming)) {
-         HashSet domset=new HashSet();
-         domset.add(incoming);
-         FlatNode tmp=incoming;
-         while(tmp!=entrance) {
-           tmp=domtree.idom(tmp);
-           domset.add(tmp);
-         }
+    /* Compute nodes that definitely get executed in a loop */
+    Set elements=l.loopIncElements();
+    Set entrances=l.loopEntrances();
+    assert entrances.size()==1;
+    FlatNode entrance=(FlatNode)entrances.iterator().next();
+    boolean first=true;
+    for (int i=0;i<entrance.numPrev();i++) {
+      FlatNode incoming=entrance.getPrev(i);
+      if (elements.contains(incoming)) {
+       HashSet domset=new HashSet();
+       domset.add(incoming);
+       FlatNode tmp=incoming;
+       while(tmp!=entrance) {
+         tmp=posttree.idom(tmp);
+         domset.add(tmp);
        }
        if (first) {
          dominatorset=domset;
@@ -144,7 +158,7 @@ qq    break;
        } else {
          for(Iterator it=dominatorset.iterator();it.hasNext();) {
            FlatNode fn=(FlatNode)it.next();
-           if (!domset.containsKey(fn))
+           if (!domset.contains(fn))
              it.remove();
          }
        }
index 9456ea1e83a056abbf58d458d66efcdb20ffb8d0..2e5845065a7a17d86c6cee3165299c7d499d53a4 100644 (file)
@@ -1,6 +1,12 @@
 package Analysis.Loops;
 
 import IR.Flat.*;
+import IR.TypeUtil;
+import IR.Operation;
+import java.util.Set;
+import java.util.Vector;
+import java.util.Iterator;
+import java.util.Hashtable;
 
 public class LoopOptimize {
   LoopInvariant loopinv;
@@ -12,7 +18,122 @@ public class LoopOptimize {
     dooptimize(fm);
   } 
   private void dooptimize(FlatMethod fm) {
+    Loops root=loopinv.loops.getRootloop(fm);
+    recurse(root);
+  }
+  private void recurse(Loops parent) {
+    processLoop(parent);
+    for(Iterator lpit=parent.nestedLoops().iterator();lpit.hasNext();) {
+      Loops child=(Loops)lpit.next();
+      recurse(child);
+    }
+  }
+  public void processLoop(Loops l) {
+    if (loopinv.tounroll.contains(l)) {
+      unrollLoop(l);
+    } else {
+      hoistOps(l);
+    }
+  }
+  public void hoistOps(Loops l) {
+    Vector<FlatNode> tohoist=loopinv.table.get(l);
+    Set lelements=l.loopIncElements();
+    TempMap t=new TempMap();
+    FlatNode first=null;
+    FlatNode last=null;
+    for(int i=0;i<tohoist.size();i++) {
+      FlatNode fn=tohoist.elementAt(i);
+      TempDescriptor[] writes=fn.writesTemps();
+      for(int j=0;j<writes.length;j++) {
+       if (writes[j]!=null&&!t.maps(writes[j])) {
+         TempDescriptor cp=writes[j].createNew();
+         t.addPair(writes[j],cp);
+       }
+      }
+      FlatNode fnnew=fn.clone(t);
+      if (first==null)
+       first=fnnew;
+      else
+       last.addNext(fnnew);
+      last=fnnew;
+      /* Splice out old node */
+      if (writes.length==1) {
+       FlatOpNode fon=new FlatOpNode(t.tempMap(writes[0]),writes[0], null, new Operation(Operation.ASSIGN));
+       fn.replace(fon);
+      } else if (writes.length>1) {
+       throw new Error();
+      }
+    }
+    /* The chain is built at this point. */
+    
+    assert l.loopEntrances().size()==1;
+    FlatNode entrance=(FlatNode)l.loopEntrances().iterator().next();
+    for(int i=0;i<entrance.numPrev();i++) {
+      FlatNode prev=entrance.getPrev(i);
+      if (!lelements.contains(prev)) {
+       //need to fix this edge
+       for(int j=0;j<prev.numNext();j++) {
+         if (prev.getNext(j)==entrance)
+           prev.setNext(j, first);
+       }
+      }
+    }
+    last.addNext(entrance);
+  }
+  public void unrollLoop(Loops l) {
+    assert l.loopEntrances().size()==1;
+    FlatNode entrance=(FlatNode)l.loopEntrances().iterator().next();
+    Set lelements=l.loopIncElements();
+    Set<FlatNode> tohoist=loopinv.hoisted;
+    Hashtable<FlatNode, TempDescriptor> temptable=new Hashtable<FlatNode, TempDescriptor>();
+    Hashtable<FlatNode, FlatNode> copytable=new Hashtable<FlatNode, FlatNode>();
+    Hashtable<FlatNode, FlatNode> copyendtable=new Hashtable<FlatNode, FlatNode>();
+    
+    TempMap t=new TempMap();
+    /* Copy the nodes */
+    for(Iterator it=lelements.iterator();it.hasNext();) {
+      FlatNode fn=(FlatNode)it.next();
+      FlatNode copy=fn.clone(t);
+      FlatNode copyend=copy;
+      if (tohoist.contains(fn)) {
+       TempDescriptor[] writes=fn.writesTemps();
+       TempDescriptor tmp=writes[0];
+       TempDescriptor ntmp=tmp.createNew();
+       temptable.put(fn, ntmp);
+       copyend=new FlatOpNode(ntmp, tmp, null, new Operation(Operation.ASSIGN));
+       copy.addNext(copyend);
+      }
+      copytable.put(fn, copy);
+      copyendtable.put(fn, copyend);
+    }
+    /* Copy the edges */
+    for(Iterator it=lelements.iterator();it.hasNext();) {
+      FlatNode fn=(FlatNode)it.next();
+      FlatNode copyend=copyendtable.get(fn);
+      for(int i=0;i<fn.numNext();i++) {
+       FlatNode nnext=fn.getNext(i);
+       if (nnext==entrance) {
+         /* Back to loop header...point to old graph */
+         copyend.addNext(nnext);
+       } else if (lelements.contains(nnext)) {
+         /* In graph...point to first graph */
+         copyend.addNext(copytable.get(nnext));
+       } else {
+         /* Outside loop */
+         /* Just goto same place as before */
+         copyend.addNext(nnext);
+       }
+      }
+    }
+    /* Splice out loop invariant stuff */
+    for(Iterator it=lelements.iterator();it.hasNext();) {
+      FlatNode fn=(FlatNode)it.next();
+      if (tohoist.contains(fn)) {
+       TempDescriptor[] writes=fn.writesTemps();
+       TempDescriptor tmp=writes[0];
+       FlatOpNode fon=new FlatOpNode(temptable.get(fn),tmp, null, new Operation(Operation.ASSIGN));
+       fn.replace(fon);
+      }
+    }
   }
-  
-
 }
index 243c7da0f004571686f0770cc339365223382ec5..7bf6b54c31f3d82e8628b34b9e446bcadd26e7dc 100644 (file)
@@ -19,12 +19,20 @@ public class UseDef{
 
   /* Return FlatNodes that define Temp */
   public Set<FlatNode> defMap(FlatNode fn, TempDescriptor t) {
-    return defs.get(new TempFlatPair(t,fn));
+    Set<FlatNode> s=defs.get(new TempFlatPair(t,fn));
+    if (s==null)
+      return new HashSet<FlatNode>();
+    else
+      return s;
   }
 
   /* Return FlatNodes that use Temp */
   public Set<FlatNode> useMap(FlatNode fn, TempDescriptor t) {
-    return uses.get(new TempFlatPair(t,fn));
+    Set<FlatNode> s=uses.get(new TempFlatPair(t,fn));
+    if (s==null)
+      return new HashSet<FlatNode>();
+    else
+      return s;
   }
 
   public void analyze(FlatMethod fm) {
@@ -40,14 +48,16 @@ public class UseDef{
       for(int i=0;i<fn.numPrev();i++) {
        FlatNode prev=fn.getPrev(i);
        Set<TempFlatPair> prevs=tmp.get(prev);
-       nexttfp:
-       for(Iterator<TempFlatPair> tfit=prevs.iterator();tfit.hasNext();) {
-         TempFlatPair tfp=tfit.next();
-         for(int j=0;j<fnwrites.length;j++) {
-           if (tfp.t==fnwrites[j])
-             continue nexttfp;
+       if (prevs!=null) {
+         nexttfp:
+         for(Iterator<TempFlatPair> tfit=prevs.iterator();tfit.hasNext();) {
+           TempFlatPair tfp=tfit.next();
+           for(int j=0;j<fnwrites.length;j++) {
+             if (tfp.t==fnwrites[j])
+               continue nexttfp;
+           }
+           s.add(tfp);
          }
-         s.add(tfp);
        }
        for(int j=0;j<fnwrites.length;j++) {
          TempFlatPair tfp=new TempFlatPair(fnwrites[j], fn);
diff --git a/Robust/src/Analysis/Loops/localCSE.java b/Robust/src/Analysis/Loops/localCSE.java
new file mode 100644 (file)
index 0000000..1678837
--- /dev/null
@@ -0,0 +1,283 @@
+package Analysis.Loops;
+
+import IR.MethodDescriptor;
+import IR.TypeDescriptor;
+import IR.TypeUtil;
+import IR.Operation;
+import IR.Flat.*;
+import IR.FieldDescriptor;
+import java.util.Set;
+import java.util.HashSet;
+import java.util.Hashtable;
+import java.util.Iterator;
+
+public class localCSE {
+  GlobalFieldType gft;
+  TypeUtil typeutil;
+  public localCSE(GlobalFieldType gft, TypeUtil typeutil) {
+    this.gft=gft;
+    this.typeutil=typeutil;
+  }
+  int index;
+
+  public Group getGroup(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
+    LocalExpression e=new LocalExpression(t);
+    return getGroup(tab, e);
+  }
+  public Group getGroup(Hashtable<LocalExpression, Group> tab, LocalExpression e) {
+    if (tab.containsKey(e))
+      return tab.get(e);
+    else {
+      Group g=new Group(index++);
+      g.set.add(e);
+      tab.put(e,g);
+      return g;
+    }
+  }
+  public TempDescriptor getTemp(Group g) {
+    for(Iterator it=g.set.iterator();it.hasNext();) {
+      LocalExpression e=(LocalExpression)it.next();
+      if (e.t!=null)
+       return e.t;
+    }
+    return null;
+  }
+
+  public void doAnalysis(FlatMethod fm) {
+    Set nodes=fm.getNodeSet();
+    HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
+    for(Iterator it=nodes.iterator();it.hasNext();) {
+      FlatNode fn=(FlatNode)it.next();
+      if (fn.numPrev()>1)
+       toanalyze.add(fn);
+    }
+    for(Iterator<FlatNode> it=toanalyze.iterator();it.hasNext();) {
+      FlatNode fn=it.next();
+      Hashtable<LocalExpression, Group> table=new Hashtable<LocalExpression,Group>();
+      do {
+       index=0;
+       switch(fn.kind()) {
+       case FKind.FlatOpNode: {
+         FlatOpNode fon=(FlatOpNode)fn;
+         Group left=getGroup(table, fon.getLeft());
+         Group right=getGroup(table, fon.getRight());
+         LocalExpression dst=new LocalExpression(fon.getDest());
+         if (fon.getOp().getOp()==Operation.ASSIGN) {
+           left.set.add(dst);
+           kill(table, fon.getDest());
+           table.put(dst, left);
+         } else {
+           LocalExpression e=new LocalExpression(left, right, fon.getOp());
+           Group g=getGroup(table,e);
+           TempDescriptor td=getTemp(g);
+           if (td!=null) {
+             FlatOpNode nfon=new FlatOpNode(fon.getDest(),td,null,new Operation(Operation.ASSIGN));
+             fn.replace(nfon);
+           }
+           g.set.add(dst);
+           kill(table, fon.getDest());
+           table.put(dst,g);
+         }
+         break;
+       }
+       case FKind.FlatLiteralNode: {
+         FlatLiteralNode fln=(FlatLiteralNode)fn;
+         LocalExpression e=new LocalExpression(fln.getValue());
+         Group src=getGroup(table, e);
+         LocalExpression dst=new LocalExpression(fln.getDst());
+         src.set.add(dst);
+         kill(table, fln.getDst());
+         table.put(dst, src);
+         break;
+       }
+       case FKind.FlatFieldNode: {
+         FlatFieldNode ffn=(FlatFieldNode) fn;
+         Group src=getGroup(table, ffn.getSrc());
+         LocalExpression e=new LocalExpression(src, ffn.getField());
+         Group srcf=getGroup(table, e);
+         LocalExpression dst=new LocalExpression(ffn.getDst());
+         TempDescriptor td=getTemp(srcf);
+         if (td!=null) {
+           FlatOpNode fon=new FlatOpNode(ffn.getDst(),td,null,new Operation(Operation.ASSIGN));
+           fn.replace(fon);
+         }
+         srcf.set.add(dst);
+         kill(table, ffn.getDst());
+         table.put(dst, srcf);
+         break;
+       }
+       case FKind.FlatElementNode: {
+         FlatElementNode fen=(FlatElementNode) fn;
+         Group src=getGroup(table, fen.getSrc());
+         Group index=getGroup(table, fen.getIndex());
+         LocalExpression e=new LocalExpression(src, fen.getSrc().getType(), index);
+         Group srcf=getGroup(table, e);
+         LocalExpression dst=new LocalExpression(fen.getDst());
+         TempDescriptor td=getTemp(srcf);
+         if (td!=null) {
+           FlatOpNode fon=new FlatOpNode(fen.getDst(),td,null,new Operation(Operation.ASSIGN));
+           fn.replace(fon);
+         }
+         srcf.set.add(dst);
+         kill(table, fen.getDst());
+         table.put(dst, srcf);
+         break;
+       }
+       case FKind.FlatSetFieldNode: {
+         FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
+         Group dst=getGroup(table, fsfn.getDst());
+         LocalExpression e=new LocalExpression(dst, fsfn.getField());
+         Group dstf=getGroup(table, e);
+         LocalExpression src=new LocalExpression(fsfn.getSrc());
+         dstf.set.add(src);
+         HashSet<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
+         fields.add(fsfn.getField());
+         kill(table, fields, null);
+         table.put(src, dstf);
+         break;
+       }
+       case FKind.FlatSetElementNode: {
+         FlatSetElementNode fsen=(FlatSetElementNode)fn;
+         Group dst=getGroup(table, fsen.getDst());
+         Group index=getGroup(table, fsen.getIndex());
+         LocalExpression e=new LocalExpression(dst, fsen.getDst().getType(), index);
+         Group dstf=getGroup(table, e);
+         LocalExpression src=new LocalExpression(fsen.getSrc());
+         dstf.set.add(src);
+         HashSet<TypeDescriptor> arrays=new HashSet<TypeDescriptor>();
+         arrays.add(fsen.getDst().getType());
+         kill(table, null, arrays);
+         table.put(src, dstf);
+         break;
+       }
+       case FKind.FlatCall:{
+         //do side effects
+         FlatCall fc=(FlatCall)fn;
+         MethodDescriptor md=fc.getMethod();
+         Set<FieldDescriptor> fields=gft.getFields(md);
+         Set<TypeDescriptor> arrays=gft.getArrays(md);
+         kill(table, fields, arrays);
+       }
+       default: {
+         TempDescriptor[] writes=fn.writesTemps();
+         for(int i=0;i<writes.length;i++) {
+           kill(table,writes[i]);
+         }
+       }
+       }
+      } while(fn.numPrev()==1);
+    }
+  }
+  public void kill(Hashtable<LocalExpression, Group> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays) {
+    Set<LocalExpression> eset=tab.keySet();
+    for(Iterator<LocalExpression> it=eset.iterator();it.hasNext();) {
+      LocalExpression e=it.next();
+      if (e.td!=null) {
+       //have array
+       TypeDescriptor artd=e.td;
+       for(Iterator<TypeDescriptor> arit=arrays.iterator();arit.hasNext();) {
+         TypeDescriptor td=arit.next();
+         if (typeutil.isSuperorType(artd,td)||
+             typeutil.isSuperorType(td,artd)) {
+           Group g=tab.get(e);
+           g.set.remove(e);
+           it.remove();
+           break;
+         }
+       }
+      } else if (e.f!=null) {
+       if (fields.contains(e.f)) {
+         Group g=tab.get(e);
+         g.set.remove(e);
+         it.remove();
+       }
+      }
+    }
+  }
+  public void kill(Hashtable<LocalExpression, Group> tab, TempDescriptor t) {
+    LocalExpression e=new LocalExpression(t);
+    Group g=tab.get(e);
+    tab.remove(e);
+    g.set.remove(e);
+  }
+}
+
+class Group {
+  HashSet set;
+  int i;
+  Group(int i) {
+    set=new HashSet();
+    this.i=i;
+  }
+  public int hashCode() {
+    return i;
+  }
+  public boolean equals(Object o) {
+    Group g=(Group)o;
+    return i==g.i;
+  }
+}
+
+class LocalExpression {
+  Operation op;
+  Object o;
+  Group a;
+  Group b;
+  TempDescriptor t;
+  FieldDescriptor f;
+  TypeDescriptor td;
+  LocalExpression(TempDescriptor t) {
+    this.t=t;
+  }
+  LocalExpression(Object o) {
+    this.o=o;
+  }
+  LocalExpression(Group a, Group b, Operation op) {
+    this.a=a;
+    this.b=b;
+    this.op=op;
+  }
+  LocalExpression(Group a, FieldDescriptor f) {
+    this.a=a;
+    this.f=f;
+  }
+  LocalExpression(Group a, TypeDescriptor td, Group index) {
+    this.a=a;
+    this.td=td;
+    this.b=index;
+  }
+  public int hashCode() {
+    int h=0;
+    if (td!=null)
+      h^=td.hashCode();
+    if (t!=null)
+      h^=t.hashCode();
+    if (a!=null)
+      h^=a.hashCode();
+    if (o!=null)
+      h^=o.hashCode();
+    if (op!=null)
+      h^=op.getOp();
+    if (b!=null)
+      h^=b.hashCode();
+    if (f!=null)
+      h^=f.hashCode();
+    return h;
+  }
+  public boolean equals(Object o) {
+    LocalExpression e=(LocalExpression)o;
+    if (a!=e.a||f!=e.f||b!=e.b)
+      return false;
+    if (td!=null) {
+      if (!td.equals(e.td))
+       return false;
+    }
+    if (o!=null) {
+      if (e.o==null)
+       return false;
+      return o.equals(e.o);
+    } else if (op!=null)
+      return op.getOp()==e.op.getOp();
+    return true;
+  }
+}
\ No newline at end of file