1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This pass eliminates allocas by either converting them into vectors or
11 // by migrating them to local address space.
13 //===----------------------------------------------------------------------===//
#include "AMDGPUSubtarget.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

#define DEBUG_TYPE "amdgpu-promote-alloca"

using namespace llvm;
28 class AMDGPUPromoteAlloca : public FunctionPass,
29 public InstVisitor<AMDGPUPromoteAlloca> {
33 const AMDGPUSubtarget &ST;
34 int LocalMemAvailable;
37 AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
38 LocalMemAvailable(0) { }
39 bool doInitialization(Module &M) override;
40 bool runOnFunction(Function &F) override;
41 const char *getPassName() const override { return "AMDGPU Promote Alloca"; }
42 void visitAlloca(AllocaInst &I);
45 } // End anonymous namespace
47 char AMDGPUPromoteAlloca::ID = 0;
49 bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
54 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
56 const FunctionType *FTy = F.getFunctionType();
58 LocalMemAvailable = ST.getLocalMemorySize();
61 // If the function has any arguments in the local address space, then it's
62 // possible these arguments require the entire local memory space, so
63 // we cannot use local memory in the pass.
64 for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
65 const Type *ParamTy = FTy->getParamType(i);
66 if (ParamTy->isPointerTy() &&
67 ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
68 LocalMemAvailable = 0;
69 DEBUG(dbgs() << "Function has local memory argument. Promoting to "
70 "local memory disabled.\n");
75 if (LocalMemAvailable > 0) {
76 // Check how much local memory is being used by global objects
77 for (Module::global_iterator I = Mod->global_begin(),
78 E = Mod->global_end(); I != E; ++I) {
79 GlobalVariable *GV = I;
80 PointerType *GVTy = GV->getType();
81 if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
83 for (Value::use_iterator U = GV->use_begin(),
84 UE = GV->use_end(); U != UE; ++U) {
85 Instruction *Use = dyn_cast<Instruction>(*U);
88 if (Use->getParent()->getParent() == &F)
90 Mod->getDataLayout()->getTypeAllocSize(GVTy->getElementType());
95 LocalMemAvailable = std::max(0, LocalMemAvailable);
96 DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
103 static VectorType *arrayTypeToVecType(const Type *ArrayTy) {
104 return VectorType::get(ArrayTy->getArrayElementType(),
105 ArrayTy->getArrayNumElements());
108 static Value* calculateVectorIndex(Value *Ptr,
109 std::map<GetElementPtrInst*, Value*> GEPIdx) {
110 if (isa<AllocaInst>(Ptr))
111 return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
113 GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
118 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
119 // FIXME we only support simple cases
120 if (GEP->getNumOperands() != 3)
123 ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
124 if (!I0 || !I0->isZero())
127 return GEP->getOperand(2);
130 // Not an instruction handled below to turn into a vector.
132 // TODO: Check isTriviallyVectorizable for calls and handle other
134 static bool canVectorizeInst(Instruction *Inst) {
135 switch (Inst->getOpcode()) {
136 case Instruction::Load:
137 case Instruction::Store:
138 case Instruction::BitCast:
139 case Instruction::AddrSpaceCast:
146 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
147 Type *AllocaTy = Alloca->getAllocatedType();
149 DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
151 // FIXME: There is no reason why we can't support larger arrays, we
152 // are just being conservative for now.
153 if (!AllocaTy->isArrayTy() ||
154 AllocaTy->getArrayElementType()->isVectorTy() ||
155 AllocaTy->getArrayNumElements() > 4) {
157 DEBUG(dbgs() << " Cannot convert type to vector");
161 std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
162 std::vector<Value*> WorkList;
163 for (User *AllocaUser : Alloca->users()) {
164 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
166 if (!canVectorizeInst(cast<Instruction>(AllocaUser)))
169 WorkList.push_back(AllocaUser);
173 Value *Index = GEPToVectorIndex(GEP);
175 // If we can't compute a vector index from this GEP, then we can't
176 // promote this alloca to vector.
178 DEBUG(dbgs() << " Cannot compute vector index for GEP " << *GEP << '\n');
182 GEPVectorIdx[GEP] = Index;
183 for (User *GEPUser : AllocaUser->users()) {
184 if (!canVectorizeInst(cast<Instruction>(GEPUser)))
187 WorkList.push_back(GEPUser);
191 VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
193 DEBUG(dbgs() << " Converting alloca to vector "
194 << *AllocaTy << " -> " << *VectorTy << '\n');
196 for (std::vector<Value*>::iterator I = WorkList.begin(),
197 E = WorkList.end(); I != E; ++I) {
198 Instruction *Inst = cast<Instruction>(*I);
199 IRBuilder<> Builder(Inst);
200 switch (Inst->getOpcode()) {
201 case Instruction::Load: {
202 Value *Ptr = Inst->getOperand(0);
203 Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
204 Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
205 Value *VecValue = Builder.CreateLoad(BitCast);
206 Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
207 Inst->replaceAllUsesWith(ExtractElement);
208 Inst->eraseFromParent();
211 case Instruction::Store: {
212 Value *Ptr = Inst->getOperand(1);
213 Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
214 Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
215 Value *VecValue = Builder.CreateLoad(BitCast);
216 Value *NewVecValue = Builder.CreateInsertElement(VecValue,
219 Builder.CreateStore(NewVecValue, BitCast);
220 Inst->eraseFromParent();
223 case Instruction::BitCast:
224 case Instruction::AddrSpaceCast:
229 llvm_unreachable("Inconsistency in instructions promotable to vector");
235 static void collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
236 for (User *User : Val->users()) {
237 if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
239 if (isa<CallInst>(User)) {
240 WorkList.push_back(User);
243 if (!User->getType()->isPointerTy())
245 WorkList.push_back(User);
246 collectUsesWithPtrTypes(User, WorkList);
250 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
251 IRBuilder<> Builder(&I);
253 // First try to replace the alloca with a vector
254 Type *AllocaTy = I.getAllocatedType();
256 DEBUG(dbgs() << "Trying to promote " << I << '\n');
258 if (tryPromoteAllocaToVector(&I))
261 DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
263 // FIXME: This is the maximum work group size. We should try to get
264 // value from the reqd_work_group_size function attribute if it is
266 unsigned WorkGroupSize = 256;
267 int AllocaSize = WorkGroupSize *
268 Mod->getDataLayout()->getTypeAllocSize(AllocaTy);
270 if (AllocaSize > LocalMemAvailable) {
271 DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
275 DEBUG(dbgs() << "Promoting alloca to local memory\n");
276 LocalMemAvailable -= AllocaSize;
278 GlobalVariable *GV = new GlobalVariable(
279 *Mod, ArrayType::get(I.getAllocatedType(), 256), false,
280 GlobalValue::ExternalLinkage, 0, I.getName(), 0,
281 GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
283 FunctionType *FTy = FunctionType::get(
284 Type::getInt32Ty(Mod->getContext()), false);
285 AttributeSet AttrSet;
286 AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
288 Value *ReadLocalSizeY = Mod->getOrInsertFunction(
289 "llvm.r600.read.local.size.y", FTy, AttrSet);
290 Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
291 "llvm.r600.read.local.size.z", FTy, AttrSet);
292 Value *ReadTIDIGX = Mod->getOrInsertFunction(
293 "llvm.r600.read.tidig.x", FTy, AttrSet);
294 Value *ReadTIDIGY = Mod->getOrInsertFunction(
295 "llvm.r600.read.tidig.y", FTy, AttrSet);
296 Value *ReadTIDIGZ = Mod->getOrInsertFunction(
297 "llvm.r600.read.tidig.z", FTy, AttrSet);
300 Value *TCntY = Builder.CreateCall(ReadLocalSizeY);
301 Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ);
302 Value *TIdX = Builder.CreateCall(ReadTIDIGX);
303 Value *TIdY = Builder.CreateCall(ReadTIDIGY);
304 Value *TIdZ = Builder.CreateCall(ReadTIDIGZ);
306 Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
307 Tmp0 = Builder.CreateMul(Tmp0, TIdX);
308 Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
309 Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
310 TID = Builder.CreateAdd(TID, TIdZ);
312 std::vector<Value*> Indices;
313 Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
314 Indices.push_back(TID);
316 Value *Offset = Builder.CreateGEP(GV, Indices);
317 I.mutateType(Offset->getType());
318 I.replaceAllUsesWith(Offset);
321 std::vector<Value*> WorkList;
323 collectUsesWithPtrTypes(Offset, WorkList);
325 for (std::vector<Value*>::iterator i = WorkList.begin(),
326 e = WorkList.end(); i != e; ++i) {
328 CallInst *Call = dyn_cast<CallInst>(V);
330 Type *EltTy = V->getType()->getPointerElementType();
331 PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
333 // The operand's value should be corrected on its own.
334 if (isa<AddrSpaceCastInst>(V))
337 // FIXME: It doesn't really make sense to try to do this for all
339 V->mutateType(NewTy);
343 IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
345 std::vector<Type*> ArgTypes;
346 for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
347 ArgIdx != ArgEnd; ++ArgIdx) {
348 ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
350 Function *F = Call->getCalledFunction();
351 FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
353 Constant *C = Mod->getOrInsertFunction(StringRef(F->getName().str() + ".local"), NewType,
355 Function *NewF = cast<Function>(C);
356 Call->setCalledFunction(NewF);
360 Builder.SetInsertPoint(Intr);
361 switch (Intr->getIntrinsicID()) {
362 case Intrinsic::lifetime_start:
363 case Intrinsic::lifetime_end:
364 // These intrinsics are for address space 0 only
365 Intr->eraseFromParent();
367 case Intrinsic::memcpy: {
368 MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
369 Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
370 MemCpy->getLength(), MemCpy->getAlignment(),
371 MemCpy->isVolatile());
372 Intr->eraseFromParent();
375 case Intrinsic::memset: {
376 MemSetInst *MemSet = cast<MemSetInst>(Intr);
377 Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
378 MemSet->getLength(), MemSet->getAlignment(),
379 MemSet->isVolatile());
380 Intr->eraseFromParent();
385 llvm_unreachable("Don't know how to promote alloca intrinsic use.");
390 FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
391 return new AMDGPUPromoteAlloca(ST);