SLPVectorizer: Add support for trees that don't start at binary operators, and add...
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE SV_NAME
20
21 #include "VecUtils.h"
22 #include "llvm/Transforms/Vectorize.h"
23 #include "llvm/Analysis/AliasAnalysis.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/Analysis/Verifier.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <map>
38
39 using namespace llvm;
40
41 static cl::opt<int>
42 SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
43                  cl::desc("Only vectorize trees if the gain is above this "
44                           "number. (gain = -cost of vectorization)"));
45 namespace {
46
47 /// The SLPVectorizer Pass.
48 struct SLPVectorizer : public BasicBlockPass {
49   typedef std::map<Value*, BoUpSLP::StoreList> StoreListMap;
50
51   /// Pass identification, replacement for typeid
52   static char ID;
53
54   explicit SLPVectorizer() : BasicBlockPass(ID) {
55     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
56   }
57
58   ScalarEvolution *SE;
59   DataLayout *DL;
60   TargetTransformInfo *TTI;
61   AliasAnalysis *AA;
62
63   /// \brief Collect memory references and sort them according to their base
64   /// object. We sort the stores to their base objects to reduce the cost of the
65   /// quadratic search on the stores. TODO: We can further reduce this cost
66   /// if we flush the chain creation every time we run into a memory barrier.
67   bool collectStores(BasicBlock *BB, BoUpSLP &R) {
68     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
69       StoreInst *SI = dyn_cast<StoreInst>(it);
70       if (!SI)
71         continue;
72
73       // Check that the pointer points to scalars.
74       if (SI->getValueOperand()->getType()->isAggregateType())
75         return false;
76
77       // Find the base of the GEP.
78       Value *Ptr = SI->getPointerOperand();
79       if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
80         Ptr = GEP->getPointerOperand();
81
82       // Save the store locations.
83       StoreRefs[Ptr].push_back(SI);
84     }
85     return true;
86   }
87
88   bool tryToVectorizePair(Value *A, Value *B,  BoUpSLP &R) {
89     if (!A || !B) return false;
90     BoUpSLP::ValueList VL;
91     VL.push_back(A);
92     VL.push_back(B);
93     int Cost = R.getTreeCost(VL);
94     int ExtrCost = R.getScalarizationCost(VL);
95     DEBUG(dbgs()<<"SLP: Cost of pair:" << Cost <<
96                   " Cost of extract:" << ExtrCost << ".\n");
97     if ((Cost+ExtrCost) >= -SLPCostThreshold) return false;
98     DEBUG(dbgs()<<"SLP: Vectorizing pair.\n");
99     R.vectorizeArith(VL);
100     return true;
101   }
102
103   bool tryToVectorizeCandidate(BinaryOperator *V,  BoUpSLP &R) {
104     if (!V) return false;
105     // Try to vectorize V.
106     if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
107       return true;
108
109     BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
110     BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
111     // Try to skip B.
112     if (B && B->hasOneUse()) {
113       BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
114       BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
115       if (tryToVectorizePair(A, B0, R)) {
116         B->moveBefore(V);
117         return true;
118       }
119       if (tryToVectorizePair(A, B1, R)) {
120         B->moveBefore(V);
121         return true;
122       }
123     }
124
125     // Try to slip A.
126     if (A && A->hasOneUse()) {
127       BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
128       BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
129       if (tryToVectorizePair(A0, B, R)) {
130         A->moveBefore(V);
131         return true;
132       }
133       if (tryToVectorizePair(A1, B, R)) {
134         A->moveBefore(V);
135         return true;
136       }
137     }
138     return 0;
139   }
140
141   bool vectorizeReductions(BasicBlock *BB, BoUpSLP &R) {
142     bool Changed = false;
143     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
144       if (isa<DbgInfoIntrinsic>(it)) continue;
145       PHINode *P = dyn_cast<PHINode>(it);
146       if (!P) return Changed;
147       // Check that the PHI is a reduction PHI.
148       if (P->getNumIncomingValues() != 2) return Changed;
149       Value *Rdx = (P->getIncomingBlock(0) == BB ? P->getIncomingValue(0) :
150                    (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
151       // Check if this is a Binary Operator.
152       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
153       if (!BI) continue;
154
155       Value *Inst = BI->getOperand(0);
156       if (Inst == P) Inst = BI->getOperand(1);
157       Changed |= tryToVectorizeCandidate(dyn_cast<BinaryOperator>(Inst), R);
158     }
159
160     return Changed;
161   }
162
163   bool rollStoreChains(BoUpSLP &R) {
164     bool Changed = false;
165     // Attempt to sort and vectorize each of the store-groups.
166     for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
167          it != e; ++it) {
168       if (it->second.size() < 2)
169         continue;
170
171       DEBUG(dbgs()<<"SLP: Analyzing a store chain of length " <<
172             it->second.size() << ".\n");
173
174       Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
175     }
176     return Changed;
177   }
178
179   virtual bool runOnBasicBlock(BasicBlock &BB) {
180     SE = &getAnalysis<ScalarEvolution>();
181     DL = getAnalysisIfAvailable<DataLayout>();
182     TTI = &getAnalysis<TargetTransformInfo>();
183     AA = &getAnalysis<AliasAnalysis>();
184     StoreRefs.clear();
185
186     // Must have DataLayout. We can't require it because some tests run w/o
187     // triple.
188     if (!DL)
189       return false;
190
191     // Use the bollom up slp vectorizer to construct chains that start with
192     // he store instructions.
193     BoUpSLP R(&BB, SE, DL, TTI, AA);
194
195     bool Changed = vectorizeReductions(&BB, R);
196
197     if (!collectStores(&BB, R))
198       return Changed;
199
200     if (rollStoreChains(R)) {
201       DEBUG(dbgs()<<"SLP: vectorized in \""<<BB.getParent()->getName()<<"\"\n");
202       DEBUG(verifyFunction(*BB.getParent()));
203       Changed |= true;
204     }
205
206     return Changed;
207   }
208
209   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
210     BasicBlockPass::getAnalysisUsage(AU);
211     AU.addRequired<ScalarEvolution>();
212     AU.addRequired<AliasAnalysis>();
213     AU.addRequired<TargetTransformInfo>();
214   }
215
216 private:
217   StoreListMap StoreRefs;
218 };
219
220 } // end anonymous namespace
221
222 char SLPVectorizer::ID = 0;
223 static const char lv_name[] = "SLP Vectorizer";
224 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
225 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
226 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
227 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
228 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
229 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
230
231 namespace llvm {
232   Pass *createSLPVectorizerPass() {
233     return new SLPVectorizer();
234   }
235 }
236