SLP Vectorizer: Erase instructions outside the vectorizeTree method.
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE "SLP"
20
21 #include "llvm/Transforms/Vectorize.h"
22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/Analysis/AliasAnalysis.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/Verifier.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <algorithm>
44 #include <map>
45
46 using namespace llvm;
47
48 static cl::opt<int>
49     SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
50                      cl::desc("Only vectorize trees if the gain is above this "
51                               "number. (gain = -cost of vectorization)"));
52 namespace {
53
54 static const unsigned MinVecRegSize = 128;
55
56 static const unsigned RecursionMaxDepth = 6;
57
58 /// RAII pattern to save the insertion point of the IR builder.
59 class BuilderLocGuard {
60 public:
61   BuilderLocGuard(IRBuilder<> &B) : Builder(B), Loc(B.GetInsertPoint()) {}
62   ~BuilderLocGuard() { Builder.SetInsertPoint(Loc); }
63
64 private:
65   // Prevent copying.
66   BuilderLocGuard(const BuilderLocGuard &);
67   BuilderLocGuard &operator=(const BuilderLocGuard &);
68   IRBuilder<> &Builder;
69   BasicBlock::iterator Loc;
70 };
71
72 /// A helper class for numbering instructions in multible blocks.
73 /// Numbers starts at zero for each basic block.
74 struct BlockNumbering {
75
76   BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {}
77
78   BlockNumbering() : BB(0), Valid(false) {}
79
80   void numberInstructions() {
81     unsigned Loc = 0;
82     InstrIdx.clear();
83     InstrVec.clear();
84     // Number the instructions in the block.
85     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
86       InstrIdx[it] = Loc++;
87       InstrVec.push_back(it);
88       assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation");
89     }
90     Valid = true;
91   }
92
93   int getIndex(Instruction *I) {
94     if (!Valid)
95       numberInstructions();
96     assert(InstrIdx.count(I) && "Unknown instruction");
97     return InstrIdx[I];
98   }
99
100   Instruction *getInstruction(unsigned loc) {
101     if (!Valid)
102       numberInstructions();
103     assert(InstrVec.size() > loc && "Invalid Index");
104     return InstrVec[loc];
105   }
106
107   void forget() { Valid = false; }
108
109 private:
110   /// The block we are numbering.
111   BasicBlock *BB;
112   /// Is the block numbered.
113   bool Valid;
114   /// Maps instructions to numbers and back.
115   SmallDenseMap<Instruction *, int> InstrIdx;
116   /// Maps integers to Instructions.
117   std::vector<Instruction *> InstrVec;
118 };
119
120 class FuncSLP {
121   typedef SmallVector<Value *, 8> ValueList;
122   typedef SmallVector<Instruction *, 16> InstrList;
123   typedef SmallPtrSet<Value *, 16> ValueSet;
124   typedef SmallVector<StoreInst *, 8> StoreList;
125
126 public:
127   static const int MAX_COST = INT_MIN;
128
129   FuncSLP(Function *Func, ScalarEvolution *Se, DataLayout *Dl,
130           TargetTransformInfo *Tti, AliasAnalysis *Aa, LoopInfo *Li) :
131     F(Func), SE(Se), DL(Dl), TTI(Tti), AA(Aa), LI(Li),
132     Builder(Se->getContext()) {
133     for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it) {
134       BasicBlock *BB = it;
135       BlocksNumbers[BB] = BlockNumbering(BB);
136     }
137   }
138
139   /// \brief Take the pointer operand from the Load/Store instruction.
140   /// \returns NULL if this is not a valid Load/Store instruction.
141   static Value *getPointerOperand(Value *I);
142
143   /// \brief Take the address space operand from the Load/Store instruction.
144   /// \returns -1 if this is not a valid Load/Store instruction.
145   static unsigned getAddressSpaceOperand(Value *I);
146
147   /// \returns true if the memory operations A and B are consecutive.
148   bool isConsecutiveAccess(Value *A, Value *B);
149
150   /// \brief Vectorize the tree that starts with the elements in \p VL.
151   /// \returns the vectorized value.
152   Value *vectorizeTree(ArrayRef<Value *> VL);
153
154   /// \returns the vectorization cost of the subtree that starts at \p VL.
155   /// A negative number means that this is profitable.
156   int getTreeCost(ArrayRef<Value *> VL);
157
158   /// \returns the scalarization cost for this list of values. Assuming that
159   /// this subtree gets vectorized, we may need to extract the values from the
160   /// roots. This method calculates the cost of extracting the values.
161   int getGatherCost(ArrayRef<Value *> VL);
162
163   /// \brief Attempts to order and vectorize a sequence of stores. This
164   /// function does a quadratic scan of the given stores.
165   /// \returns true if the basic block was modified.
166   bool vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold);
167
168   /// \brief Vectorize a group of scalars into a vector tree.
169   /// \returns the vectorized value.
170   Value *vectorizeArith(ArrayRef<Value *> Operands);
171
172   /// \brief This method contains the recursive part of getTreeCost.
173   int getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth);
174
175   /// \brief This recursive method looks for vectorization hazards such as
176   /// values that are used by multiple users and checks that values are used
177   /// by only one vector lane. It updates the variables LaneMap, MultiUserVals.
178   void getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth);
179
180   /// \brief This method contains the recursive part of vectorizeTree.
181   Value *vectorizeTree_rec(ArrayRef<Value *> VL);
182
183   ///  \brief Vectorize a sorted sequence of stores.
184   bool vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold);
185
186   /// \returns the scalarization cost for this type. Scalarization in this
187   /// context means the creation of vectors from a group of scalars.
188   int getGatherCost(Type *Ty);
189
190   /// \returns the AA location that is being access by the instruction.
191   AliasAnalysis::Location getLocation(Instruction *I);
192
193   /// \brief Checks if it is possible to sink an instruction from
194   /// \p Src to \p Dst.
195   /// \returns the pointer to the barrier instruction if we can't sink.
196   Value *getSinkBarrier(Instruction *Src, Instruction *Dst);
197
198   /// \returns the index of the last instrucion in the BB from \p VL.
199   int getLastIndex(ArrayRef<Value *> VL);
200
201   /// \returns the Instrucion in the bundle \p VL.
202   Instruction *getLastInstruction(ArrayRef<Value *> VL);
203
204   /// \returns the Instruction at index \p Index which is in Block \p BB.
205   Instruction *getInstructionForIndex(unsigned Index, BasicBlock *BB);
206
207   /// \returns the index of the first User of \p VL.
208   int getFirstUserIndex(ArrayRef<Value *> VL);
209
210   /// \returns a vector from a collection of scalars in \p VL.
211   Value *Gather(ArrayRef<Value *> VL, VectorType *Ty);
212
213   /// \brief Perform LICM and CSE on the newly generated gather sequences.
214   void optimizeGatherSequence();
215
216   bool needToGatherAny(ArrayRef<Value *> VL) {
217     for (int i = 0, e = VL.size(); i < e; ++i)
218       if (MustGather.count(VL[i]))
219         return true;
220     return false;
221   }
222
223   /// -- Vectorization State --
224
225   /// Maps values in the tree to the vector lanes that uses them. This map must
226   /// be reset between runs of getCost.
227   std::map<Value *, int> LaneMap;
228   /// A list of instructions to ignore while sinking
229   /// memory instructions. This map must be reset between runs of getCost.
230   ValueSet MemBarrierIgnoreList;
231
232   /// Maps between the first scalar to the vector. This map must be reset
233   /// between runs.
234   DenseMap<Value *, Value *> VectorizedValues;
235
236   /// Contains values that must be gathered because they are used
237   /// by multiple lanes, or by users outside the tree.
238   /// NOTICE: The vectorization methods also use this set.
239   ValueSet MustGather;
240
241   /// Contains a list of values that are used outside the current tree. This
242   /// set must be reset between runs.
243   SetVector<Value *> MultiUserVals;
244
245   /// Holds all of the instructions that we gathered.
246   SetVector<Instruction *> GatherSeq;
247
248   /// Numbers instructions in different blocks.
249   std::map<BasicBlock *, BlockNumbering> BlocksNumbers;
250
251   // Analysis and block reference.
252   Function *F;
253   ScalarEvolution *SE;
254   DataLayout *DL;
255   TargetTransformInfo *TTI;
256   AliasAnalysis *AA;
257   LoopInfo *LI;
258   /// Instruction builder to construct the vectorized tree.
259   IRBuilder<> Builder;
260 };
261
262 int FuncSLP::getGatherCost(Type *Ty) {
263   int Cost = 0;
264   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
265     Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i);
266   return Cost;
267 }
268
269 int FuncSLP::getGatherCost(ArrayRef<Value *> VL) {
270   // Find the type of the operands in VL.
271   Type *ScalarTy = VL[0]->getType();
272   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
273     ScalarTy = SI->getValueOperand()->getType();
274   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
275   // Find the cost of inserting/extracting values from the vector.
276   return getGatherCost(VecTy);
277 }
278
279 AliasAnalysis::Location FuncSLP::getLocation(Instruction *I) {
280   if (StoreInst *SI = dyn_cast<StoreInst>(I))
281     return AA->getLocation(SI);
282   if (LoadInst *LI = dyn_cast<LoadInst>(I))
283     return AA->getLocation(LI);
284   return AliasAnalysis::Location();
285 }
286
287 Value *FuncSLP::getPointerOperand(Value *I) {
288   if (LoadInst *LI = dyn_cast<LoadInst>(I))
289     return LI->getPointerOperand();
290   if (StoreInst *SI = dyn_cast<StoreInst>(I))
291     return SI->getPointerOperand();
292   return 0;
293 }
294
295 unsigned FuncSLP::getAddressSpaceOperand(Value *I) {
296   if (LoadInst *L = dyn_cast<LoadInst>(I))
297     return L->getPointerAddressSpace();
298   if (StoreInst *S = dyn_cast<StoreInst>(I))
299     return S->getPointerAddressSpace();
300   return -1;
301 }
302
303 bool FuncSLP::isConsecutiveAccess(Value *A, Value *B) {
304   Value *PtrA = getPointerOperand(A);
305   Value *PtrB = getPointerOperand(B);
306   unsigned ASA = getAddressSpaceOperand(A);
307   unsigned ASB = getAddressSpaceOperand(B);
308
309   // Check that the address spaces match and that the pointers are valid.
310   if (!PtrA || !PtrB || (ASA != ASB))
311     return false;
312
313   // Check that A and B are of the same type.
314   if (PtrA->getType() != PtrB->getType())
315     return false;
316
317   // Calculate the distance.
318   const SCEV *PtrSCEVA = SE->getSCEV(PtrA);
319   const SCEV *PtrSCEVB = SE->getSCEV(PtrB);
320   const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEVA, PtrSCEVB);
321   const SCEVConstant *ConstOffSCEV = dyn_cast<SCEVConstant>(OffsetSCEV);
322
323   // Non constant distance.
324   if (!ConstOffSCEV)
325     return false;
326
327   int64_t Offset = ConstOffSCEV->getValue()->getSExtValue();
328   Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
329   // The Instructions are connsecutive if the size of the first load/store is
330   // the same as the offset.
331   int64_t Sz = DL->getTypeStoreSize(Ty);
332   return ((-Offset) == Sz);
333 }
334
335 Value *FuncSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) {
336   assert(Src->getParent() == Dst->getParent() && "Not the same BB");
337   BasicBlock::iterator I = Src, E = Dst;
338   /// Scan all of the instruction from SRC to DST and check if
339   /// the source may alias.
340   for (++I; I != E; ++I) {
341     // Ignore store instructions that are marked as 'ignore'.
342     if (MemBarrierIgnoreList.count(I))
343       continue;
344     if (Src->mayWriteToMemory()) /* Write */ {
345       if (!I->mayReadOrWriteMemory())
346         continue;
347     } else /* Read */ {
348       if (!I->mayWriteToMemory())
349         continue;
350     }
351     AliasAnalysis::Location A = getLocation(&*I);
352     AliasAnalysis::Location B = getLocation(Src);
353
354     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
355       return I;
356   }
357   return 0;
358 }
359
360 static BasicBlock *getSameBlock(ArrayRef<Value *> VL) {
361   BasicBlock *BB = 0;
362   for (int i = 0, e = VL.size(); i < e; i++) {
363     Instruction *I = dyn_cast<Instruction>(VL[i]);
364     if (!I)
365       return 0;
366
367     if (!BB) {
368       BB = I->getParent();
369       continue;
370     }
371
372     if (BB != I->getParent())
373       return 0;
374   }
375   return BB;
376 }
377
378 static bool allConstant(ArrayRef<Value *> VL) {
379   for (unsigned i = 0, e = VL.size(); i < e; ++i)
380     if (!isa<Constant>(VL[i]))
381       return false;
382   return true;
383 }
384
385 static bool isSplat(ArrayRef<Value *> VL) {
386   for (unsigned i = 1, e = VL.size(); i < e; ++i)
387     if (VL[i] != VL[0])
388       return false;
389   return true;
390 }
391
392 static unsigned getSameOpcode(ArrayRef<Value *> VL) {
393   unsigned Opcode = 0;
394   for (int i = 0, e = VL.size(); i < e; i++) {
395     if (Instruction *I = dyn_cast<Instruction>(VL[i])) {
396       if (!Opcode) {
397         Opcode = I->getOpcode();
398         continue;
399       }
400       if (Opcode != I->getOpcode())
401         return 0;
402     }
403   }
404   return Opcode;
405 }
406
407 static bool CanReuseExtract(ArrayRef<Value *> VL, unsigned VF,
408                             VectorType *VecTy) {
409   assert(Instruction::ExtractElement == getSameOpcode(VL) && "Invalid opcode");
410   // Check if all of the extracts come from the same vector and from the
411   // correct offset.
412   Value *VL0 = VL[0];
413   ExtractElementInst *E0 = cast<ExtractElementInst>(VL0);
414   Value *Vec = E0->getOperand(0);
415
416   // We have to extract from the same vector type.
417   if (Vec->getType() != VecTy)
418     return false;
419
420   // Check that all of the indices extract from the correct offset.
421   ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1));
422   if (!CI || CI->getZExtValue())
423     return false;
424
425   for (unsigned i = 1, e = VF; i < e; ++i) {
426     ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
427     ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1));
428
429     if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec)
430       return false;
431   }
432
433   return true;
434 }
435
436 void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
437   if (Depth == RecursionMaxDepth)
438     return MustGather.insert(VL.begin(), VL.end());
439
440   // Don't handle vectors.
441   if (VL[0]->getType()->isVectorTy())
442     return;
443
444   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
445     if (SI->getValueOperand()->getType()->isVectorTy())
446       return;
447
448   // If all of the operands are identical or constant we have a simple solution.
449   if (allConstant(VL) || isSplat(VL) || !getSameBlock(VL))
450     return MustGather.insert(VL.begin(), VL.end());
451
452   // Stop the scan at unknown IR.
453   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
454   assert(VL0 && "Invalid instruction");
455
456   // Mark instructions with multiple users.
457   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
458     Instruction *I = dyn_cast<Instruction>(VL[i]);
459     // Remember to check if all of the users of this instruction are vectorized
460     // within our tree. At depth zero we have no local users, only external
461     // users that we don't care about.
462     if (Depth && I && I->getNumUses() > 1) {
463       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
464                       "because it has multiple users:" << *I << " \n");
465       MultiUserVals.insert(I);
466     }
467   }
468
469   // Check that the instruction is only used within one lane.
470   for (int i = 0, e = VL.size(); i < e; ++i) {
471     if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i) {
472       DEBUG(dbgs() << "SLP: Value used by multiple lanes:" << *VL[i] << "\n");
473       return MustGather.insert(VL.begin(), VL.end());
474     }
475     // Make this instruction as 'seen' and remember the lane.
476     LaneMap[VL[i]] = i;
477   }
478
479   unsigned Opcode = getSameOpcode(VL);
480   if (!Opcode)
481     return MustGather.insert(VL.begin(), VL.end());
482
483   switch (Opcode) {
484   case Instruction::ExtractElement: {
485     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
486     // No need to follow ExtractElements that are going to be optimized away.
487     if (CanReuseExtract(VL, VL.size(), VecTy))
488       return;
489     // Fall through.
490   }
491   case Instruction::Load:
492     return;
493   case Instruction::ZExt:
494   case Instruction::SExt:
495   case Instruction::FPToUI:
496   case Instruction::FPToSI:
497   case Instruction::FPExt:
498   case Instruction::PtrToInt:
499   case Instruction::IntToPtr:
500   case Instruction::SIToFP:
501   case Instruction::UIToFP:
502   case Instruction::Trunc:
503   case Instruction::FPTrunc:
504   case Instruction::BitCast:
505   case Instruction::Select:
506   case Instruction::ICmp:
507   case Instruction::FCmp:
508   case Instruction::Add:
509   case Instruction::FAdd:
510   case Instruction::Sub:
511   case Instruction::FSub:
512   case Instruction::Mul:
513   case Instruction::FMul:
514   case Instruction::UDiv:
515   case Instruction::SDiv:
516   case Instruction::FDiv:
517   case Instruction::URem:
518   case Instruction::SRem:
519   case Instruction::FRem:
520   case Instruction::Shl:
521   case Instruction::LShr:
522   case Instruction::AShr:
523   case Instruction::And:
524   case Instruction::Or:
525   case Instruction::Xor: {
526     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
527       ValueList Operands;
528       // Prepare the operand vector.
529       for (unsigned j = 0; j < VL.size(); ++j)
530         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
531
532       getTreeUses_rec(Operands, Depth + 1);
533     }
534     return;
535   }
536   case Instruction::Store: {
537     ValueList Operands;
538     for (unsigned j = 0; j < VL.size(); ++j)
539       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
540     getTreeUses_rec(Operands, Depth + 1);
541     return;
542   }
543   default:
544     return MustGather.insert(VL.begin(), VL.end());
545   }
546 }
547
548 int FuncSLP::getLastIndex(ArrayRef<Value *> VL) {
549   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
550   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
551   BlockNumbering &BN = BlocksNumbers[BB];
552
553   int MaxIdx = BN.getIndex(BB->getFirstNonPHI());
554   for (unsigned i = 0, e = VL.size(); i < e; ++i)
555     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
556   return MaxIdx;
557 }
558
559 Instruction *FuncSLP::getLastInstruction(ArrayRef<Value *> VL) {
560   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
561   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
562   BlockNumbering &BN = BlocksNumbers[BB];
563
564   int MaxIdx = BN.getIndex(cast<Instruction>(VL[0]));
565   for (unsigned i = 1, e = VL.size(); i < e; ++i)
566     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
567   return BN.getInstruction(MaxIdx);
568 }
569
570 Instruction *FuncSLP::getInstructionForIndex(unsigned Index, BasicBlock *BB) {
571   BlockNumbering &BN = BlocksNumbers[BB];
572   return BN.getInstruction(Index);
573 }
574
575 int FuncSLP::getFirstUserIndex(ArrayRef<Value *> VL) {
576   BasicBlock *BB = getSameBlock(VL);
577   assert(BB && "All instructions must come from the same block");
578   BlockNumbering &BN = BlocksNumbers[BB];
579
580   // Find the first user of the values.
581   int FirstUser = BN.getIndex(BB->getTerminator());
582   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
583     for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
584          U != UE; ++U) {
585       Instruction *Instr = dyn_cast<Instruction>(*U);
586
587       if (!Instr || Instr->getParent() != BB)
588         continue;
589
590       FirstUser = std::min(FirstUser, BN.getIndex(Instr));
591     }
592   }
593   return FirstUser;
594 }
595
596 int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
597   Type *ScalarTy = VL[0]->getType();
598
599   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
600     ScalarTy = SI->getValueOperand()->getType();
601
602   /// Don't mess with vectors.
603   if (ScalarTy->isVectorTy())
604     return FuncSLP::MAX_COST;
605
606   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
607
608   if (allConstant(VL))
609     return 0;
610
611   if (isSplat(VL))
612     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
613
614   if (Depth == RecursionMaxDepth || needToGatherAny(VL))
615     return getGatherCost(VecTy);
616
617   BasicBlock *BB = getSameBlock(VL);
618   unsigned Opcode = getSameOpcode(VL);
619   assert(Opcode && BB && "Invalid Instruction Value");
620
621   // Check if it is safe to sink the loads or the stores.
622   if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
623     int MaxIdx = getLastIndex(VL);
624     Instruction *Last = getInstructionForIndex(MaxIdx, BB);
625
626     for (unsigned i = 0, e = VL.size(); i < e; ++i) {
627       if (VL[i] == Last)
628         continue;
629       Value *Barrier = getSinkBarrier(cast<Instruction>(VL[i]), Last);
630       if (Barrier) {
631         DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
632                      << "\n because of " << *Barrier << "\n");
633         return MAX_COST;
634       }
635     }
636   }
637
638   Instruction *VL0 = cast<Instruction>(VL[0]);
639   switch (Opcode) {
640   case Instruction::ExtractElement: {
641     if (CanReuseExtract(VL, VL.size(), VecTy))
642       return 0;
643     return getGatherCost(VecTy);
644   }
645   case Instruction::ZExt:
646   case Instruction::SExt:
647   case Instruction::FPToUI:
648   case Instruction::FPToSI:
649   case Instruction::FPExt:
650   case Instruction::PtrToInt:
651   case Instruction::IntToPtr:
652   case Instruction::SIToFP:
653   case Instruction::UIToFP:
654   case Instruction::Trunc:
655   case Instruction::FPTrunc:
656   case Instruction::BitCast: {
657     ValueList Operands;
658     Type *SrcTy = VL0->getOperand(0)->getType();
659     // Prepare the operand vector.
660     for (unsigned j = 0; j < VL.size(); ++j) {
661       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
662       // Check that the casted type is the same for all users.
663       if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
664         return getGatherCost(VecTy);
665     }
666
667     int Cost = getTreeCost_rec(Operands, Depth + 1);
668     if (Cost == FuncSLP::MAX_COST)
669       return Cost;
670
671     // Calculate the cost of this instruction.
672     int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
673                                                        VL0->getType(), SrcTy);
674
675     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
676     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
677     Cost += (VecCost - ScalarCost);
678     return Cost;
679   }
680   case Instruction::FCmp:
681   case Instruction::ICmp: {
682     // Check that all of the compares have the same predicate.
683     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
684     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
685       CmpInst *Cmp = cast<CmpInst>(VL[i]);
686       if (Cmp->getPredicate() != P0)
687         return getGatherCost(VecTy);
688     }
689     // Fall through.
690   }
691   case Instruction::Select:
692   case Instruction::Add:
693   case Instruction::FAdd:
694   case Instruction::Sub:
695   case Instruction::FSub:
696   case Instruction::Mul:
697   case Instruction::FMul:
698   case Instruction::UDiv:
699   case Instruction::SDiv:
700   case Instruction::FDiv:
701   case Instruction::URem:
702   case Instruction::SRem:
703   case Instruction::FRem:
704   case Instruction::Shl:
705   case Instruction::LShr:
706   case Instruction::AShr:
707   case Instruction::And:
708   case Instruction::Or:
709   case Instruction::Xor: {
710     int TotalCost = 0;
711     // Calculate the cost of all of the operands.
712     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
713       ValueList Operands;
714       // Prepare the operand vector.
715       for (unsigned j = 0; j < VL.size(); ++j)
716         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
717
718       int Cost = getTreeCost_rec(Operands, Depth + 1);
719       if (Cost == MAX_COST)
720         return MAX_COST;
721       TotalCost += TotalCost;
722     }
723
724     // Calculate the cost of this instruction.
725     int ScalarCost = 0;
726     int VecCost = 0;
727     if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp ||
728         Opcode == Instruction::Select) {
729       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
730       ScalarCost =
731           VecTy->getNumElements() *
732           TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
733       VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
734     } else {
735       ScalarCost = VecTy->getNumElements() *
736                    TTI->getArithmeticInstrCost(Opcode, ScalarTy);
737       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
738     }
739     TotalCost += (VecCost - ScalarCost);
740     return TotalCost;
741   }
742   case Instruction::Load: {
743     // If we are scalarize the loads, add the cost of forming the vector.
744     for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
745       if (!isConsecutiveAccess(VL[i], VL[i + 1]))
746         return getGatherCost(VecTy);
747
748     // Cost of wide load - cost of scalar loads.
749     int ScalarLdCost = VecTy->getNumElements() *
750                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
751     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
752     return VecLdCost - ScalarLdCost;
753   }
754   case Instruction::Store: {
755     // We know that we can merge the stores. Calculate the cost.
756     int ScalarStCost = VecTy->getNumElements() *
757                        TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
758     int VecStCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
759     int StoreCost = VecStCost - ScalarStCost;
760
761     ValueList Operands;
762     for (unsigned j = 0; j < VL.size(); ++j) {
763       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
764       MemBarrierIgnoreList.insert(VL[j]);
765     }
766
767     int Cost = getTreeCost_rec(Operands, Depth + 1);
768     if (Cost == MAX_COST)
769       return MAX_COST;
770
771     int TotalCost = StoreCost + Cost;
772     return TotalCost;
773   }
774   default:
775     // Unable to vectorize unknown instructions.
776     return getGatherCost(VecTy);
777   }
778 }
779
780 int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
781   // Get rid of the list of stores that were removed, and from the
782   // lists of instructions with multiple users.
783   MemBarrierIgnoreList.clear();
784   LaneMap.clear();
785   MultiUserVals.clear();
786   MustGather.clear();
787
788   if (!getSameBlock(VL))
789     return MAX_COST;
790
791   // Find the location of the last root.
792   int LastRootIndex = getLastIndex(VL);
793   int FirstUserIndex = getFirstUserIndex(VL);
794
795   // Don't vectorize if there are users of the tree roots inside the tree
796   // itself.
797   if (LastRootIndex > FirstUserIndex)
798     return MAX_COST;
799
800   // Scan the tree and find which value is used by which lane, and which values
801   // must be scalarized.
802   getTreeUses_rec(VL, 0);
803
804   // Check that instructions with multiple users can be vectorized. Mark unsafe
805   // instructions.
806   for (SetVector<Value *>::iterator it = MultiUserVals.begin(),
807                                     e = MultiUserVals.end();
808        it != e; ++it) {
809     // Check that all of the users of this instr are within the tree.
810     for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
811          I != E; ++I) {
812       if (LaneMap.find(*I) == LaneMap.end()) {
813         DEBUG(dbgs() << "SLP: Adding to MustExtract "
814                         "because of an out of tree usage.\n");
815         MustGather.insert(*it);
816         continue;
817       }
818     }
819   }
820
821   // Now calculate the cost of vectorizing the tree.
822   return getTreeCost_rec(VL, 0);
823 }
824 bool FuncSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
825   unsigned ChainLen = Chain.size();
826   DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen
827                << "\n");
828   Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
829   unsigned Sz = DL->getTypeSizeInBits(StoreTy);
830   unsigned VF = MinVecRegSize / Sz;
831
832   if (!isPowerOf2_32(Sz) || VF < 2)
833     return false;
834
835   bool Changed = false;
836   // Look for profitable vectorizable trees at all offsets, starting at zero.
837   for (unsigned i = 0, e = ChainLen; i < e; ++i) {
838     if (i + VF > e)
839       break;
840     DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
841                  << "\n");
842     ArrayRef<Value *> Operands = Chain.slice(i, VF);
843
844     int Cost = getTreeCost(Operands);
845     if (Cost == FuncSLP::MAX_COST)
846       continue;
847     DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
848     if (Cost < CostThreshold) {
849       DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
850       vectorizeTree(Operands);
851
852       // Remove the scalar stores.
853       for (int i = 0, e = VF; i < e; ++i)
854         cast<Instruction>(Operands[i])->eraseFromParent();
855
856       // Move to the next bundle.
857       i += VF - 1;
858       Changed = true;
859     }
860   }
861
862   if (Changed || ChainLen > VF)
863     return Changed;
864
865   // Handle short chains. This helps us catch types such as <3 x float> that
866   // are smaller than vector size.
867   int Cost = getTreeCost(Chain);
868   if (Cost == FuncSLP::MAX_COST)
869     return false;
870   if (Cost < CostThreshold) {
871     DEBUG(dbgs() << "SLP: Found store chain cost = " << Cost
872                  << " for size = " << ChainLen << "\n");
873     vectorizeTree(Chain);
874
875     // Remove all of the scalar stores.
876     for (int i = 0, e = Chain.size(); i < e; ++i)
877       cast<Instruction>(Chain[i])->eraseFromParent();
878
879     return true;
880   }
881
882   return false;
883 }
884
885 bool FuncSLP::vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold) {
886   SetVector<Value *> Heads, Tails;
887   SmallDenseMap<Value *, Value *> ConsecutiveChain;
888
889   // We may run into multiple chains that merge into a single chain. We mark the
890   // stores that we vectorized so that we don't visit the same store twice.
891   ValueSet VectorizedStores;
892   bool Changed = false;
893
894   // Do a quadratic search on all of the given stores and find
895   // all of the pairs of loads that follow each other.
896   for (unsigned i = 0, e = Stores.size(); i < e; ++i)
897     for (unsigned j = 0; j < e; ++j) {
898       if (i == j)
899         continue;
900
901       if (isConsecutiveAccess(Stores[i], Stores[j])) {
902         Tails.insert(Stores[j]);
903         Heads.insert(Stores[i]);
904         ConsecutiveChain[Stores[i]] = Stores[j];
905       }
906     }
907
908   // For stores that start but don't end a link in the chain:
909   for (SetVector<Value *>::iterator it = Heads.begin(), e = Heads.end();
910        it != e; ++it) {
911     if (Tails.count(*it))
912       continue;
913
914     // We found a store instr that starts a chain. Now follow the chain and try
915     // to vectorize it.
916     ValueList Operands;
917     Value *I = *it;
918     // Collect the chain into a list.
919     while (Tails.count(I) || Heads.count(I)) {
920       if (VectorizedStores.count(I))
921         break;
922       Operands.push_back(I);
923       // Move to the next value in the chain.
924       I = ConsecutiveChain[I];
925     }
926
927     bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
928
929     // Mark the vectorized stores so that we don't vectorize them again.
930     if (Vectorized)
931       VectorizedStores.insert(Operands.begin(), Operands.end());
932     Changed |= Vectorized;
933   }
934
935   return Changed;
936 }
937
938 Value *FuncSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
939   Value *Vec = UndefValue::get(Ty);
940   // Generate the 'InsertElement' instruction.
941   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
942     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
943     if (Instruction *I = dyn_cast<Instruction>(Vec))
944       GatherSeq.insert(I);
945   }
946
947   return Vec;
948 }
949
950 Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
951   BuilderLocGuard Guard(Builder);
952
953   Type *ScalarTy = VL[0]->getType();
954   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
955     ScalarTy = SI->getValueOperand()->getType();
956   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
957
958   if (needToGatherAny(VL))
959     return Gather(VL, VecTy);
960
961   if (VectorizedValues.count(VL[0])) {
962     DEBUG(dbgs() << "SLP: Diamond merged at depth.\n");
963     return VectorizedValues[VL[0]];
964   }
965
966   Instruction *VL0 = cast<Instruction>(VL[0]);
967   unsigned Opcode = VL0->getOpcode();
968   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
969
970   switch (Opcode) {
971   case Instruction::ExtractElement: {
972     if (CanReuseExtract(VL, VL.size(), VecTy))
973       return VL0->getOperand(0);
974     return Gather(VL, VecTy);
975   }
976   case Instruction::ZExt:
977   case Instruction::SExt:
978   case Instruction::FPToUI:
979   case Instruction::FPToSI:
980   case Instruction::FPExt:
981   case Instruction::PtrToInt:
982   case Instruction::IntToPtr:
983   case Instruction::SIToFP:
984   case Instruction::UIToFP:
985   case Instruction::Trunc:
986   case Instruction::FPTrunc:
987   case Instruction::BitCast: {
988     ValueList INVL;
989     for (int i = 0, e = VL.size(); i < e; ++i)
990       INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
991
992     Builder.SetInsertPoint(getLastInstruction(VL));
993     Value *InVec = vectorizeTree_rec(INVL);
994     CastInst *CI = dyn_cast<CastInst>(VL0);
995     Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
996     VectorizedValues[VL0] = V;
997     return V;
998   }
999   case Instruction::FCmp:
1000   case Instruction::ICmp: {
1001     // Check that all of the compares have the same predicate.
1002     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
1003     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
1004       CmpInst *Cmp = cast<CmpInst>(VL[i]);
1005       if (Cmp->getPredicate() != P0)
1006         return Gather(VL, VecTy);
1007     }
1008
1009     ValueList LHSV, RHSV;
1010     for (int i = 0, e = VL.size(); i < e; ++i) {
1011       LHSV.push_back(cast<Instruction>(VL[i])->getOperand(0));
1012       RHSV.push_back(cast<Instruction>(VL[i])->getOperand(1));
1013     }
1014
1015     Builder.SetInsertPoint(getLastInstruction(VL));
1016     Value *L = vectorizeTree_rec(LHSV);
1017     Value *R = vectorizeTree_rec(RHSV);
1018     Value *V;
1019
1020     if (Opcode == Instruction::FCmp)
1021       V = Builder.CreateFCmp(P0, L, R);
1022     else
1023       V = Builder.CreateICmp(P0, L, R);
1024
1025     VectorizedValues[VL0] = V;
1026     return V;
1027   }
1028   case Instruction::Select: {
1029     ValueList TrueVec, FalseVec, CondVec;
1030     for (int i = 0, e = VL.size(); i < e; ++i) {
1031       CondVec.push_back(cast<Instruction>(VL[i])->getOperand(0));
1032       TrueVec.push_back(cast<Instruction>(VL[i])->getOperand(1));
1033       FalseVec.push_back(cast<Instruction>(VL[i])->getOperand(2));
1034     }
1035
1036     Builder.SetInsertPoint(getLastInstruction(VL));
1037     Value *True = vectorizeTree_rec(TrueVec);
1038     Value *False = vectorizeTree_rec(FalseVec);
1039     Value *Cond = vectorizeTree_rec(CondVec);
1040     Value *V = Builder.CreateSelect(Cond, True, False);
1041     VectorizedValues[VL0] = V;
1042     return V;
1043   }
1044   case Instruction::Add:
1045   case Instruction::FAdd:
1046   case Instruction::Sub:
1047   case Instruction::FSub:
1048   case Instruction::Mul:
1049   case Instruction::FMul:
1050   case Instruction::UDiv:
1051   case Instruction::SDiv:
1052   case Instruction::FDiv:
1053   case Instruction::URem:
1054   case Instruction::SRem:
1055   case Instruction::FRem:
1056   case Instruction::Shl:
1057   case Instruction::LShr:
1058   case Instruction::AShr:
1059   case Instruction::And:
1060   case Instruction::Or:
1061   case Instruction::Xor: {
1062     ValueList LHSVL, RHSVL;
1063     for (int i = 0, e = VL.size(); i < e; ++i) {
1064       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1065       RHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
1066     }
1067
1068     Builder.SetInsertPoint(getLastInstruction(VL));
1069     Value *LHS = vectorizeTree_rec(LHSVL);
1070     Value *RHS = vectorizeTree_rec(RHSVL);
1071
1072     if (LHS == RHS) {
1073       assert((VL0->getOperand(0) == VL0->getOperand(1)) && "Invalid order");
1074     }
1075
1076     BinaryOperator *BinOp = cast<BinaryOperator>(VL0);
1077     Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS);
1078     VectorizedValues[VL0] = V;
1079     return V;
1080   }
1081   case Instruction::Load: {
1082     // Check if all of the loads are consecutive.
1083     for (unsigned i = 1, e = VL.size(); i < e; ++i)
1084       if (!isConsecutiveAccess(VL[i - 1], VL[i]))
1085         return Gather(VL, VecTy);
1086
1087     // Loads are inserted at the head of the tree because we don't want to
1088     // sink them all the way down past store instructions.
1089     Builder.SetInsertPoint(getLastInstruction(VL));
1090     LoadInst *LI = cast<LoadInst>(VL0);
1091     Value *VecPtr =
1092         Builder.CreateBitCast(LI->getPointerOperand(), VecTy->getPointerTo());
1093     unsigned Alignment = LI->getAlignment();
1094     LI = Builder.CreateLoad(VecPtr);
1095     LI->setAlignment(Alignment);
1096
1097     VectorizedValues[VL0] = LI;
1098     return LI;
1099   }
1100   case Instruction::Store: {
1101     StoreInst *SI = cast<StoreInst>(VL0);
1102     unsigned Alignment = SI->getAlignment();
1103
1104     ValueList ValueOp;
1105     for (int i = 0, e = VL.size(); i < e; ++i)
1106       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
1107
1108     Value *VecValue = vectorizeTree_rec(ValueOp);
1109
1110     Builder.SetInsertPoint(getLastInstruction(VL));
1111     Value *VecPtr =
1112         Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo());
1113     Builder.CreateStore(VecValue, VecPtr)->setAlignment(Alignment);
1114     return 0;
1115   }
1116   default:
1117     return Gather(VL, VecTy);
1118   }
1119 }
1120
1121 Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
1122   Builder.SetInsertPoint(getLastInstruction(VL));
1123   Value *V = vectorizeTree_rec(VL);
1124
1125   // We moved some instructions around. We have to number them again
1126   // before we can do any analysis.
1127   for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it)
1128     BlocksNumbers[it].forget();
1129   // Clear the state.
1130   MustGather.clear();
1131   VectorizedValues.clear();
1132   MemBarrierIgnoreList.clear();
1133   return V;
1134 }
1135
1136 Value *FuncSLP::vectorizeArith(ArrayRef<Value *> Operands) {
1137   Value *Vec = vectorizeTree(Operands);
1138   // After vectorizing the operands we need to generate extractelement
1139   // instructions and replace all of the uses of the scalar values with
1140   // the values that we extracted from the vectorized tree.
1141   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
1142     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
1143     Operands[i]->replaceAllUsesWith(S);
1144   }
1145
1146   return Vec;
1147 }
1148
1149 void FuncSLP::optimizeGatherSequence() {
1150   // LICM InsertElementInst sequences.
1151   for (SetVector<Instruction *>::iterator it = GatherSeq.begin(),
1152        e = GatherSeq.end(); it != e; ++it) {
1153     InsertElementInst *Insert = dyn_cast<InsertElementInst>(*it);
1154
1155     if (!Insert)
1156       continue;
1157
1158     // Check if this block is inside a loop.
1159     Loop *L = LI->getLoopFor(Insert->getParent());
1160     if (!L)
1161       continue;
1162
1163     // Check if it has a preheader.
1164     BasicBlock *PreHeader = L->getLoopPreheader();
1165     if (!PreHeader)
1166       return;
1167
1168     // If the vector or the element that we insert into it are
1169     // instructions that are defined in this basic block then we can't
1170     // hoist this instruction.
1171     Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0));
1172     Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1));
1173     if (CurrVec && L->contains(CurrVec))
1174       continue;
1175     if (NewElem && L->contains(NewElem))
1176       continue;
1177
1178     // We can hoist this instruction. Move it to the pre-header.
1179     Insert->moveBefore(PreHeader->getTerminator());
1180   }
1181
1182   // Perform O(N^2) search over the gather sequences and merge identical
1183   // instructions. TODO: We can further optimize this scan if we split the
1184   // instructions into different buckets based on the insert lane.
1185   SmallPtrSet<Instruction*, 16> Visited;
1186   ReversePostOrderTraversal<Function*> RPOT(F);
1187   for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(),
1188        E = RPOT.end(); I != E; ++I) {
1189     BasicBlock *BB = *I;
1190     // For all instructions in the function:
1191     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1192       InsertElementInst *Insert = dyn_cast<InsertElementInst>(it);
1193       if (!Insert || !GatherSeq.count(Insert))
1194         continue;
1195
1196      // Check if we can replace this instruction with any of the
1197      // visited instructions.
1198       for (SmallPtrSet<Instruction*, 16>::iterator v = Visited.begin(),
1199            ve = Visited.end(); v != ve; ++v) {
1200         if (Insert->isIdenticalTo(*v)) {
1201           Insert->replaceAllUsesWith(*v);
1202           break;
1203         }
1204       }
1205       Visited.insert(Insert);
1206     }
1207   }
1208 }
1209
1210 /// The SLPVectorizer Pass.
1211 struct SLPVectorizer : public FunctionPass {
1212   typedef SmallVector<StoreInst *, 8> StoreList;
1213   typedef MapVector<Value *, StoreList> StoreListMap;
1214
1215   /// Pass identification, replacement for typeid
1216   static char ID;
1217
1218   explicit SLPVectorizer() : FunctionPass(ID) {
1219     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
1220   }
1221
1222   ScalarEvolution *SE;
1223   DataLayout *DL;
1224   TargetTransformInfo *TTI;
1225   AliasAnalysis *AA;
1226   LoopInfo *LI;
1227
1228   virtual bool runOnFunction(Function &F) {
1229     SE = &getAnalysis<ScalarEvolution>();
1230     DL = getAnalysisIfAvailable<DataLayout>();
1231     TTI = &getAnalysis<TargetTransformInfo>();
1232     AA = &getAnalysis<AliasAnalysis>();
1233     LI = &getAnalysis<LoopInfo>();
1234
1235     StoreRefs.clear();
1236     bool Changed = false;
1237
1238     // Must have DataLayout. We can't require it because some tests run w/o
1239     // triple.
1240     if (!DL)
1241       return false;
1242
1243     DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n");
1244
1245     // Use the bollom up slp vectorizer to construct chains that start with
1246     // he store instructions.
1247     FuncSLP R(&F, SE, DL, TTI, AA, LI);
1248
1249     for (Function::iterator it = F.begin(), e = F.end(); it != e; ++it) {
1250       BasicBlock *BB = it;
1251
1252       // Vectorize trees that end at reductions.
1253       Changed |= vectorizeChainsInBlock(BB, R);
1254
1255       // Vectorize trees that end at stores.
1256       if (unsigned count = collectStores(BB, R)) {
1257         (void)count;
1258         DEBUG(dbgs() << "SLP: Found " << count << " stores to vectorize.\n");
1259         Changed |= vectorizeStoreChains(R);
1260       }
1261     }
1262
1263     if (Changed) {
1264       R.optimizeGatherSequence();
1265       DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
1266       DEBUG(verifyFunction(F));
1267     }
1268     return Changed;
1269   }
1270
1271   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
1272     FunctionPass::getAnalysisUsage(AU);
1273     AU.addRequired<ScalarEvolution>();
1274     AU.addRequired<AliasAnalysis>();
1275     AU.addRequired<TargetTransformInfo>();
1276     AU.addRequired<LoopInfo>();
1277   }
1278
1279 private:
1280
1281   /// \brief Collect memory references and sort them according to their base
1282   /// object. We sort the stores to their base objects to reduce the cost of the
1283   /// quadratic search on the stores. TODO: We can further reduce this cost
1284   /// if we flush the chain creation every time we run into a memory barrier.
1285   unsigned collectStores(BasicBlock *BB, FuncSLP &R);
1286
1287   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
1288   bool tryToVectorizePair(Value *A, Value *B, FuncSLP &R);
1289
1290   /// \brief Try to vectorize a list of operands. If \p NeedExtracts is true
1291   /// then we calculate the cost of extracting the scalars from the vector.
1292   /// \returns true if a value was vectorized.
1293   bool tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R, bool NeedExtracts);
1294
1295   /// \brief Try to vectorize a chain that may start at the operands of \V;
1296   bool tryToVectorize(BinaryOperator *V, FuncSLP &R);
1297
1298   /// \brief Vectorize the stores that were collected in StoreRefs.
1299   bool vectorizeStoreChains(FuncSLP &R);
1300
1301   /// \brief Scan the basic block and look for patterns that are likely to start
1302   /// a vectorization chain.
1303   bool vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R);
1304
1305 private:
1306   StoreListMap StoreRefs;
1307 };
1308
1309 unsigned SLPVectorizer::collectStores(BasicBlock *BB, FuncSLP &R) {
1310   unsigned count = 0;
1311   StoreRefs.clear();
1312   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1313     StoreInst *SI = dyn_cast<StoreInst>(it);
1314     if (!SI)
1315       continue;
1316
1317     // Check that the pointer points to scalars.
1318     Type *Ty = SI->getValueOperand()->getType();
1319     if (Ty->isAggregateType() || Ty->isVectorTy())
1320       return 0;
1321
1322     // Find the base of the GEP.
1323     Value *Ptr = SI->getPointerOperand();
1324     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
1325       Ptr = GEP->getPointerOperand();
1326
1327     // Save the store locations.
1328     StoreRefs[Ptr].push_back(SI);
1329     count++;
1330   }
1331   return count;
1332 }
1333
1334 bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, FuncSLP &R) {
1335   if (!A || !B)
1336     return false;
1337   Value *VL[] = { A, B };
1338   return tryToVectorizeList(VL, R, true);
1339 }
1340
1341 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R,
1342                                        bool NeedExtracts) {
1343   if (VL.size() < 2)
1344     return false;
1345
1346   DEBUG(dbgs() << "SLP: Vectorizing a list of length = " << VL.size() << ".\n");
1347
1348   // Check that all of the parts are scalar instructions of the same type.
1349   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
1350   if (!I0)
1351     return 0;
1352
1353   unsigned Opcode0 = I0->getOpcode();
1354
1355   for (int i = 0, e = VL.size(); i < e; ++i) {
1356     Type *Ty = VL[i]->getType();
1357     if (Ty->isAggregateType() || Ty->isVectorTy())
1358       return 0;
1359     Instruction *Inst = dyn_cast<Instruction>(VL[i]);
1360     if (!Inst || Inst->getOpcode() != Opcode0)
1361       return 0;
1362   }
1363
1364   int Cost = R.getTreeCost(VL);
1365   if (Cost == FuncSLP::MAX_COST)
1366     return false;
1367
1368   int ExtrCost = NeedExtracts ? R.getGatherCost(VL) : 0;
1369   DEBUG(dbgs() << "SLP: Cost of pair:" << Cost
1370                << " Cost of extract:" << ExtrCost << ".\n");
1371   if ((Cost + ExtrCost) >= -SLPCostThreshold)
1372     return false;
1373   DEBUG(dbgs() << "SLP: Vectorizing pair.\n");
1374   R.vectorizeArith(VL);
1375   return true;
1376 }
1377
1378 bool SLPVectorizer::tryToVectorize(BinaryOperator *V, FuncSLP &R) {
1379   if (!V)
1380     return false;
1381
1382   // Try to vectorize V.
1383   if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
1384     return true;
1385
1386   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
1387   BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
1388   // Try to skip B.
1389   if (B && B->hasOneUse()) {
1390     BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
1391     BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
1392     if (tryToVectorizePair(A, B0, R)) {
1393       B->moveBefore(V);
1394       return true;
1395     }
1396     if (tryToVectorizePair(A, B1, R)) {
1397       B->moveBefore(V);
1398       return true;
1399     }
1400   }
1401
1402   // Try to skip A.
1403   if (A && A->hasOneUse()) {
1404     BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
1405     BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
1406     if (tryToVectorizePair(A0, B, R)) {
1407       A->moveBefore(V);
1408       return true;
1409     }
1410     if (tryToVectorizePair(A1, B, R)) {
1411       A->moveBefore(V);
1412       return true;
1413     }
1414   }
1415   return 0;
1416 }
1417
1418 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R) {
1419   bool Changed = false;
1420   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1421     if (isa<DbgInfoIntrinsic>(it))
1422       continue;
1423
1424     // Try to vectorize reductions that use PHINodes.
1425     if (PHINode *P = dyn_cast<PHINode>(it)) {
1426       // Check that the PHI is a reduction PHI.
1427       if (P->getNumIncomingValues() != 2)
1428         return Changed;
1429       Value *Rdx =
1430           (P->getIncomingBlock(0) == BB
1431                ? (P->getIncomingValue(0))
1432                : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
1433       // Check if this is a Binary Operator.
1434       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
1435       if (!BI)
1436         continue;
1437
1438       Value *Inst = BI->getOperand(0);
1439       if (Inst == P)
1440         Inst = BI->getOperand(1);
1441
1442       Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
1443       continue;
1444     }
1445
1446     // Try to vectorize trees that start at compare instructions.
1447     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
1448       if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
1449         Changed |= true;
1450         continue;
1451       }
1452       for (int i = 0; i < 2; ++i)
1453         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
1454           Changed |=
1455               tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R);
1456       continue;
1457     }
1458   }
1459
1460   // Scan the PHINodes in our successors in search for pairing hints.
1461   for (succ_iterator it = succ_begin(BB), e = succ_end(BB); it != e; ++it) {
1462     BasicBlock *Succ = *it;
1463     SmallVector<Value *, 4> Incoming;
1464
1465     // Collect the incoming values from the PHIs.
1466     for (BasicBlock::iterator instr = Succ->begin(), ie = Succ->end();
1467          instr != ie; ++instr) {
1468       PHINode *P = dyn_cast<PHINode>(instr);
1469
1470       if (!P)
1471         break;
1472
1473       Value *V = P->getIncomingValueForBlock(BB);
1474       if (Instruction *I = dyn_cast<Instruction>(V))
1475         if (I->getParent() == BB)
1476           Incoming.push_back(I);
1477     }
1478
1479     if (Incoming.size() > 1)
1480       Changed |= tryToVectorizeList(Incoming, R, true);
1481   }
1482
1483   return Changed;
1484 }
1485
1486 bool SLPVectorizer::vectorizeStoreChains(FuncSLP &R) {
1487   bool Changed = false;
1488   // Attempt to sort and vectorize each of the store-groups.
1489   for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
1490        it != e; ++it) {
1491     if (it->second.size() < 2)
1492       continue;
1493
1494     DEBUG(dbgs() << "SLP: Analyzing a store chain of length "
1495                  << it->second.size() << ".\n");
1496
1497     Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
1498   }
1499   return Changed;
1500 }
1501
1502 } // end anonymous namespace
1503
1504 char SLPVectorizer::ID = 0;
1505 static const char lv_name[] = "SLP Vectorizer";
1506 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
1507 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
1508 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
1509 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
1510 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1511 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
1512
1513 namespace llvm {
1514 Pass *createSLPVectorizerPass() { return new SLPVectorizer(); }
1515 }