Change references to the Method class to be references to the Function
[oota-llvm.git] / lib / Transforms / Scalar / InstructionCombining.cpp
1 //===- InstructionCombining.cpp - Combine multiple instructions -------------=//
2 //
3 // InstructionCombining - Combine instructions to form fewer, simple
4 //   instructions.  This pass does not modify the CFG, and has a tendancy to
5 //   make instructions dead, so a subsequent DCE pass is useful.
6 //
7 // This pass combines things like:
8 //    %Y = add int 1, %X
9 //    %Z = add int 1, %Y
10 // into:
11 //    %Z = add int 2, %X
12 //
13 // This is a simple worklist driven algorithm.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "llvm/Transforms/Scalar/InstructionCombining.h"
18 #include "llvm/Transforms/Scalar/ConstantHandling.h"
19 #include "llvm/Function.h"
20 #include "llvm/iMemory.h"
21 #include "llvm/InstrTypes.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/InstIterator.h"
24 #include "../TransformInternals.h"
25
26 static Instruction *CombineBinOp(BinaryOperator *I) {
27   bool Changed = false;
28
29   // First thing we do is make sure that this instruction has a constant on the
30   // right hand side if it has any constant arguments.
31   //
32   if (isa<Constant>(I->getOperand(0)) && !isa<Constant>(I->getOperand(1)))
33     if (!I->swapOperands())
34       Changed = true;
35
36   bool LocalChange = true;
37   while (LocalChange) {
38     LocalChange = false;
39     Value *Op1 = I->getOperand(0);
40     if (Constant *Op2 = dyn_cast<Constant>(I->getOperand(1))) {
41       switch (I->getOpcode()) {
42       case Instruction::Add:
43         if (I->getType()->isIntegral() && cast<ConstantInt>(Op2)->equalsInt(0)){
44           // Eliminate 'add int %X, 0'
45           I->replaceAllUsesWith(Op1);       // FIXME: This breaks the worklist
46           Changed = true;
47           return I;
48         }
49
50         if (Instruction *IOp1 = dyn_cast<Instruction>(Op1)) {
51           if (IOp1->getOpcode() == Instruction::Add &&
52               isa<Constant>(IOp1->getOperand(1))) {
53             // Fold:
54             //    %Y = add int %X, 1
55             //    %Z = add int %Y, 1
56             // into:
57             //    %Z = add int %X, 2
58             //   
59             // Constant fold both constants...
60             Constant *Val = *Op2 + *cast<Constant>(IOp1->getOperand(1));
61             
62             if (Val) {
63               I->setOperand(0, IOp1->getOperand(0));
64               I->setOperand(1, Val);
65               LocalChange = true;
66               break;
67             }
68           }
69           
70         }
71         break;
72
73       case Instruction::Mul:
74         if (I->getType()->isIntegral() && cast<ConstantInt>(Op2)->equalsInt(1)){
75           // Eliminate 'mul int %X, 1'
76           I->replaceAllUsesWith(Op1);      // FIXME: This breaks the worklist
77           LocalChange = true;
78           break;
79         }
80
81       default:
82         break;
83       }
84     }
85     Changed |= LocalChange;
86   }
87
88   if (!Changed) return 0;
89   return I;
90 }
91
92 // Combine Indices - If the source pointer to this mem access instruction is a
93 // getelementptr instruction, combine the indices of the GEP into this
94 // instruction
95 //
96 static Instruction *CombineIndicies(MemAccessInst *MAI) {
97   GetElementPtrInst *Src =
98     dyn_cast<GetElementPtrInst>(MAI->getPointerOperand());
99   if (!Src) return 0;
100
101   std::vector<Value *> Indices;
102   
103   // Only special case we have to watch out for is pointer arithmetic on the
104   // 0th index of MAI. 
105   unsigned FirstIdx = MAI->getFirstIndexOperandNumber();
106   if (FirstIdx == MAI->getNumOperands() || 
107       (FirstIdx == MAI->getNumOperands()-1 &&
108        MAI->getOperand(FirstIdx) == ConstantUInt::get(Type::UIntTy, 0))) { 
109     // Replace the index list on this MAI with the index on the getelementptr
110     Indices.insert(Indices.end(), Src->idx_begin(), Src->idx_end());
111   } else if (*MAI->idx_begin() == ConstantUInt::get(Type::UIntTy, 0)) { 
112     // Otherwise we can do the fold if the first index of the GEP is a zero
113     Indices.insert(Indices.end(), Src->idx_begin(), Src->idx_end());
114     Indices.insert(Indices.end(), MAI->idx_begin()+1, MAI->idx_end());
115   }
116
117   if (Indices.empty()) return 0;  // Can't do the fold?
118
119   switch (MAI->getOpcode()) {
120   case Instruction::GetElementPtr:
121     return new GetElementPtrInst(Src->getOperand(0), Indices, MAI->getName());
122   case Instruction::Load:
123     return new LoadInst(Src->getOperand(0), Indices, MAI->getName());
124   case Instruction::Store:
125     return new StoreInst(MAI->getOperand(0), Src->getOperand(0),
126                          Indices, MAI->getName());
127   default:
128     assert(0 && "Unknown memaccessinst!");
129     break;
130   }
131   abort();
132   return 0;
133 }
134
135 static bool CombineInstruction(Instruction *I) {
136   Instruction *Result = 0;
137   if (BinaryOperator *BOP = dyn_cast<BinaryOperator>(I))
138     Result = CombineBinOp(BOP);
139   else if (MemAccessInst *MAI = dyn_cast<MemAccessInst>(I))
140     Result = CombineIndicies(MAI);
141
142   if (!Result) return false;
143   if (Result == I) return true;
144
145   // If we get to here, we are to replace I with Result.
146   ReplaceInstWithInst(I, Result);
147   return true;
148 }
149
150 static bool doInstCombining(Function *M) {
151   // Start the worklist out with all of the instructions in the function in it.
152   std::vector<Instruction*> WorkList(inst_begin(M), inst_end(M));
153
154   while (!WorkList.empty()) {
155     Instruction *I = WorkList.back();  // Get an instruction from the worklist
156     WorkList.pop_back();
157
158     // Now that we have an instruction, try combining it to simplify it...
159     if (CombineInstruction(I)) {
160       // The instruction was simplified, add all users of the instruction to
161       // the work lists because they might get more simplified now...
162       //
163       for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
164            UI != UE; ++UI)
165         if (Instruction *User = dyn_cast<Instruction>(*UI))
166           WorkList.push_back(User);
167     }
168   }
169
170   return false;
171 }
172
173 namespace {
174   struct InstructionCombining : public MethodPass {
175     virtual bool runOnMethod(Function *F) { return doInstCombining(F); }
176   };
177 }
178
179 Pass *createInstructionCombiningPass() {
180   return new InstructionCombining();
181 }