New pass to decompose multi-dimensional array references into
[oota-llvm.git] / lib / Transforms / Scalar / DecomposeMultiDimRefs.cpp
1 //===- llvm/Transforms/DecomposeArrayRefs.cpp - Lower array refs to 1D -----=//
2 //
3 // DecomposeArrayRefs - 
4 // Convert multi-dimensional array references into a sequence of
5 // instructions (using getelementpr and cast) so that each instruction
6 // has at most one array offset.
7 //
8 //===---------------------------------------------------------------------===//
9
10 #include "llvm/Transforms/DecomposeArrayRefs.h"
11 #include "llvm/iMemory.h"
12 #include "llvm/iOther.h"
13 #include "llvm/BasicBlock.h"
14 #include "llvm/Method.h"
15 #include "llvm/Pass.h"
16
17
18 // 
19 // This function repeats until we have a one-dim. reference: {
20 //      // For an N-dim array ref, where N > 1, insert:
21 //      aptr1 = getElementPtr [N-dim array] * lastPtr, uint firstIndex
22 //      aptr2 = cast [N-dim-arry] * aptr to [<N-1>-dim-array] *
23 // }
24 // Then it replaces the original instruction with an equivalent one that
25 // uses the last aptr2 generated in the loop and a single index.
26 // 
27 static BasicBlock::reverse_iterator
28 decomposeArrayRef(BasicBlock::reverse_iterator& BBI)
29 {
30   MemAccessInst *memI = cast<MemAccessInst>(*BBI);
31   BasicBlock* BB = memI->getParent();
32   Value* lastPtr = memI->getPointerOperand();
33   vector<Instruction*> newIvec;
34   
35   MemAccessInst::const_op_iterator OI = memI->idx_begin();
36   for (MemAccessInst::const_op_iterator OE = memI->idx_end(); OI != OE; ++OI)
37     {
38       if (OI+1 == OE)                     // skip the last operand
39         break;
40       
41       assert(isa<PointerType>(lastPtr->getType()));
42       vector<Value*> idxVec(1, *OI);
43
44       // The first index does not change the type of the pointer
45       // since all pointers are treated as potential arrays (i.e.,
46       // int *X is either a scalar X[0] or an array at X[i]).
47       // 
48       const Type* nextPtrType;
49       // if (OI == memI->idx_begin())
50       //   nextPtrType = lastPtr->getType();
51       // else
52       //   {
53              const Type* nextArrayType =  
54                MemAccessInst::getIndexedType(lastPtr->getType(), idxVec,
55                                              /*allowCompositeLeaf*/ true);
56              nextPtrType = PointerType::get(cast<SequentialType>(nextArrayType)
57                                             ->getElementType());
58       //   }
59       
60       Instruction* gepInst  = new GetElementPtrInst(lastPtr, idxVec, "aptr1");
61       Instruction* castInst = new CastInst(gepInst, nextPtrType, "aptr2");
62       lastPtr  = castInst;
63       
64       newIvec.push_back(gepInst);
65       newIvec.push_back(castInst);
66     }
67   
68   // Now create a new instruction to replace the original one
69   assert(lastPtr != memI->getPointerOperand() && "the above loop did not execute?");
70   assert(isa<PointerType>(lastPtr->getType()));
71   vector<Value*> idxVec(1, *OI);
72   const std::string newInstName = memI->hasName()? memI->getName()
73                                                  : string("oneDimRef");
74   Instruction* newInst = NULL;
75   
76   switch(memI->getOpcode())
77     {
78     case Instruction::Load:
79       newInst = new LoadInst(lastPtr, idxVec /*, newInstName */); break;
80     case Instruction::Store:
81       newInst = new StoreInst(memI->getOperand(0),
82                               lastPtr, idxVec /*, newInstName */); break;
83       break;
84     case Instruction::GetElementPtr:
85       newInst = new GetElementPtrInst(lastPtr, idxVec /*, newInstName */); break;
86     default:
87       assert(0 && "Unrecognized memory access instruction"); break;
88     }
89   
90   newIvec.push_back(newInst);
91   
92   // Replace all uses of the old instruction with the new
93   memI->replaceAllUsesWith(newInst);
94   
95   // Insert the instructions created in reverse order.  insert is destructive
96   // so we always have to use the new pointer returned by insert.
97   BasicBlock::iterator newI = BBI.base(); // gives ptr to instr. after memI
98   --newI;                                 // step back to memI
99   for (int i = newIvec.size()-1; i >= 0; i--)
100     newI = BB->getInstList().insert(newI, newIvec[i]);
101   
102   // Now delete the old instruction and return a pointer to the first new one
103   BB->getInstList().remove(memI);
104   delete memI;
105   
106   BasicBlock::reverse_iterator retI(newI); // reverse ptr to instr before newI
107   return --retI;                           // reverse pointer to newI
108 }
109
110
111 //---------------------------------------------------------------------------
112 // Entry point for decomposing multi-dimensional array references
113 //---------------------------------------------------------------------------
114
115 static bool
116 doDecomposeArrayRefs(Method *M)
117 {
118   bool changed = false;
119   
120   for (Method::iterator BI = M->begin(), BE = M->end(); BI != BE; ++BI)
121     for (BasicBlock::reverse_iterator newI, II=(*BI)->rbegin();
122          II != (*BI)->rend(); II = ++newI)
123       {
124         newI = II;
125         if (MemAccessInst *memI = dyn_cast<MemAccessInst>(*II))
126           { // Check for a multi-dimensional array access
127             const PointerType* ptrType =
128               cast<PointerType>(memI->getPointerOperand()->getType()); 
129             if (isa<ArrayType>(ptrType->getElementType()) &&
130                 memI->getNumOperands() > 1+ memI->getFirstIndexOperandNumber())
131               {
132                 newI = decomposeArrayRef(II);
133                 changed = true;
134               }
135           }
136       }
137   
138   return changed;
139 }
140
141
142 namespace {
143   struct DecomposeArrayRefsPass : public MethodPass {
144     virtual bool runOnMethod(Method *M) { return doDecomposeArrayRefs(M); }
145   };
146 }
147
148 Pass *createDecomposeArrayRefsPass() { return new DecomposeArrayRefsPass(); }