Implement the trivial cases in InstCombine/store.ll
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/InstVisitor.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include <algorithm>
25 #include <map>
26 #include <iostream>
27
28 using namespace llvm;
29
30 namespace {
31
32 /// This pass converts packed operators to an
33 /// equivalent operations on smaller packed data, to possibly
34 /// scalar operations.  Currently it supports lowering
35 /// to scalar operations.
36 ///
37 /// @brief Transforms packed instructions to simpler instructions.
38 ///
39 class LowerPacked : public FunctionPass, public InstVisitor<LowerPacked> {
40 public:
41    /// @brief Lowers packed operations to scalar operations. 
42    /// @param F The fuction to process
43    virtual bool runOnFunction(Function &F);
44
45    /// @brief Lowers packed load instructions.
46    /// @param LI the load instruction to convert
47    void visitLoadInst(LoadInst& LI);
48
49    /// @brief Lowers packed store instructions.
50    /// @param SI the store instruction to convert
51    void visitStoreInst(StoreInst& SI);
52
53    /// @brief Lowers packed binary operations.
54    /// @param BO the binary operator to convert
55    void visitBinaryOperator(BinaryOperator& BO);
56
57    /// @brief Lowers packed select instructions.
58    /// @param SELI the select operator to convert
59    void visitSelectInst(SelectInst& SELI);
60
61    /// This function asserts if the instruction is a PackedType but
62    /// is handled by another function.
63    /// 
64    /// @brief Asserts if PackedType instruction is not handled elsewhere.
65    /// @param I the unhandled instruction
66    void visitInstruction(Instruction &I)
67    {
68       if(isa<PackedType>(I.getType())) {
69          std::cerr << "Unhandled Instruction with Packed ReturnType: " << 
70                       I << '\n';
71       }
72    }
73 private:
74    /// @brief Retrieves lowered values for a packed value.
75    /// @param val the packed value
76    /// @return the lowered values
77    std::vector<Value*>& getValues(Value* val);
78
79    /// @brief Sets lowered values for a packed value.
80    /// @param val the packed value
81    /// @param values the corresponding lowered values
82    void setValues(Value* val,const std::vector<Value*>& values);
83
84    // Data Members
85    /// @brief whether we changed the function or not   
86    bool Changed;
87
88    /// @brief a map from old packed values to new smaller packed values
89    std::map<Value*,std::vector<Value*> > packedToScalarMap;
90
91    /// Instructions in the source program to get rid of
92    /// after we do a pass (the old packed instructions)
93    std::vector<Instruction*> instrsToRemove;
94 }; 
95
96 RegisterOpt<LowerPacked> 
97 X("lower-packed", 
98   "lowers packed operations to operations on smaller packed datatypes");
99
100 } // end namespace   
101
102 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
103
104
105 // This function sets lowered values for a corresponding
106 // packed value.  Note, in the case of a forward reference
107 // getValues(Value*) will have already been called for 
108 // the packed parameter.  This function will then replace 
109 // all references in the in the function of the "dummy" 
110 // value the previous getValues(Value*) call 
111 // returned with actual references.
112 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
113 {
114    std::map<Value*,std::vector<Value*> >::iterator it = 
115          packedToScalarMap.lower_bound(value);
116    if (it == packedToScalarMap.end() || it->first != value) {
117        // there was not a forward reference to this element
118        packedToScalarMap.insert(it,std::make_pair(value,values));
119    }
120    else {
121       // replace forward declarations with actual definitions
122       assert(it->second.size() == values.size() && 
123              "Error forward refences and actual definition differ in size");
124       for (unsigned i = 0, e = values.size(); i != e; ++i) {
125            // replace and get rid of old forward references
126            it->second[i]->replaceAllUsesWith(values[i]);
127            delete it->second[i];
128            it->second[i] = values[i];
129       }
130    }
131 }
132
133 // This function will examine the packed value parameter
134 // and if it is a packed constant or a forward reference
135 // properly create the lowered values needed.  Otherwise
136 // it will simply retreive values from a  
137 // setValues(Value*,const std::vector<Value*>&) 
138 // call.  Failing both of these cases, it will abort
139 // the program.
140 std::vector<Value*>& LowerPacked::getValues(Value* value)
141 {
142    assert(isa<PackedType>(value->getType()) &&
143           "Value must be PackedType");
144
145    // reject further processing if this one has
146    // already been handled
147    std::map<Value*,std::vector<Value*> >::iterator it = 
148       packedToScalarMap.lower_bound(value);
149    if (it != packedToScalarMap.end() && it->first == value) {
150        return it->second;
151    }
152
153    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
154        // non-zero constant case
155        std::vector<Value*> results;
156        results.reserve(CP->getNumOperands());
157        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
158           results.push_back(CP->getOperand(i));
159        }
160        return packedToScalarMap.insert(it,
161                                        std::make_pair(value,results))->second;
162    }
163    else if (ConstantAggregateZero* CAZ =
164             dyn_cast<ConstantAggregateZero>(value)) {
165        // zero constant 
166        const PackedType* PKT = cast<PackedType>(CAZ->getType());
167        std::vector<Value*> results;
168        results.reserve(PKT->getNumElements());
169    
170        Constant* C = Constant::getNullValue(PKT->getElementType());
171        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
172             results.push_back(C);
173        }
174        return packedToScalarMap.insert(it,
175                                        std::make_pair(value,results))->second;
176    }
177    else if (isa<Instruction>(value)) {
178        // foward reference
179        const PackedType* PKT = cast<PackedType>(value->getType());
180        std::vector<Value*> results;
181        results.reserve(PKT->getNumElements());
182    
183       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
184            results.push_back(new Argument(PKT->getElementType()));
185       }
186       return packedToScalarMap.insert(it,
187                                       std::make_pair(value,results))->second;
188    }
189    else {
190        // we don't know what it is, and we are trying to retrieve
191        // a value for it
192        assert(false && "Unhandled PackedType value");
193        abort();
194    }
195 }
196
197 void LowerPacked::visitLoadInst(LoadInst& LI)
198 {
199    // Make sure what we are dealing with is a packed type
200    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
201        // Initialization, Idx is needed for getelementptr needed later
202        std::vector<Value*> Idx(2);
203        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
204
205        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
206                                       PKT->getNumElements());
207        PointerType* APT = PointerType::get(AT);
208
209        // Cast the packed type to an array
210        Value* array = new CastInst(LI.getPointerOperand(),
211                                    APT,
212                                    LI.getName() + ".a",
213                                    &LI);
214
215        // Convert this load into num elements number of loads
216        std::vector<Value*> values;
217        values.reserve(PKT->getNumElements());
218
219        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
220             // Calculate the second index we will need
221             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
222
223             // Get the pointer
224             Value* val = new GetElementPtrInst(array, 
225                                                Idx,
226                                                LI.getName() + 
227                                                ".ge." + utostr(i),
228                                                &LI);
229
230             // generate the new load and save the result in packedToScalar map
231             values.push_back(new LoadInst(val, 
232                              LI.getName()+"."+utostr(i),
233                              LI.isVolatile(),
234                              &LI));
235        }
236                
237        setValues(&LI,values);
238        Changed = true;
239        instrsToRemove.push_back(&LI);
240    }
241 }
242
243 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
244 {
245    // Make sure both operands are PackedTypes
246    if (isa<PackedType>(BO.getOperand(0)->getType())) {
247        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
248        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
249        std::vector<Value*> result;
250        assert((op0Vals.size() == op1Vals.size()) &&
251               "The two packed operand to scalar maps must be equal in size.");
252
253        result.reserve(op0Vals.size());
254    
255        // generate the new binary op and save the result
256        for (unsigned i = 0; i != op0Vals.size(); ++i) {
257             result.push_back(BinaryOperator::create(BO.getOpcode(), 
258                                                     op0Vals[i], 
259                                                     op1Vals[i],
260                                                     BO.getName() + 
261                                                     "." + utostr(i),
262                                                     &BO));
263        }
264
265        setValues(&BO,result);
266        Changed = true;
267        instrsToRemove.push_back(&BO);
268    }
269 }
270
271 void LowerPacked::visitStoreInst(StoreInst& SI)
272 {
273    if (const PackedType* PKT = 
274        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
275        // We will need this for getelementptr
276        std::vector<Value*> Idx(2);
277        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
278          
279        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
280                                       PKT->getNumElements());
281        PointerType* APT = PointerType::get(AT);
282
283        // cast the packed to an array type
284        Value* array = new CastInst(SI.getPointerOperand(),
285                                    APT,
286                                    "store.ge.a.",
287                                    &SI);
288        std::vector<Value*>& values = getValues(SI.getOperand(0));
289       
290        assert((values.size() == PKT->getNumElements()) &&
291               "Scalar must have the same number of elements as Packed Type");
292
293        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
294             // Generate the indices for getelementptr
295             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
296             Value* val = new GetElementPtrInst(array, 
297                                                Idx,
298                                                "store.ge." +
299                                                utostr(i) + ".",
300                                                &SI);
301             new StoreInst(values[i], val, SI.isVolatile(),&SI);
302        }
303                  
304        Changed = true;
305        instrsToRemove.push_back(&SI);
306    }
307 }
308
309 void LowerPacked::visitSelectInst(SelectInst& SELI)
310 {
311    // Make sure both operands are PackedTypes
312    if (isa<PackedType>(SELI.getType())) {
313        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
314        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
315        std::vector<Value*> result;
316
317       assert((op0Vals.size() == op1Vals.size()) &&
318              "The two packed operand to scalar maps must be equal in size.");
319
320       for (unsigned i = 0; i != op0Vals.size(); ++i) {
321            result.push_back(new SelectInst(SELI.getCondition(),
322                                            op0Vals[i], 
323                                            op1Vals[i],
324                                            SELI.getName()+ "." + utostr(i),
325                                            &SELI));
326       }
327    
328       setValues(&SELI,result);
329       Changed = true;
330       instrsToRemove.push_back(&SELI);
331    }
332 }
333
334 bool LowerPacked::runOnFunction(Function& F)
335 {
336    // initialize
337    Changed = false; 
338   
339    // Does three passes:
340    // Pass 1) Converts Packed Operations to 
341    //         new Packed Operations on smaller
342    //         datatypes
343    visit(F);
344   
345    // Pass 2) Drop all references
346    std::for_each(instrsToRemove.begin(),
347                  instrsToRemove.end(),
348                  std::mem_fun(&Instruction::dropAllReferences));
349
350    // Pass 3) Delete the Instructions to remove aka packed instructions
351    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(), 
352                                             e = instrsToRemove.end(); 
353         i != e; ++i) {
354         (*i)->getParent()->getInstList().erase(*i);   
355    }
356
357    // clean-up
358    packedToScalarMap.clear();
359    instrsToRemove.clear();
360
361    return Changed;
362 }
363