1 //===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // Copy struct args to local memory. This is needed for kernel functions only.
11 // This is a preparation for handling cases like
13 // kernel void foo(struct A arg, ...)
15 // struct A *p = &arg;
17 // ... = p->field1 ... (there is no generic address for .param)
18 // p->field2 = ... (there is no write access to .param)
21 //===----------------------------------------------------------------------===//
24 #include "NVPTXUtilities.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Pass.h"
35 void initializeNVPTXLowerStructArgsPass(PassRegistry &);
39 class NVPTXLowerStructArgs : public FunctionPass {
40 bool runOnFunction(Function &F) override;
42 void handleStructPtrArgs(Function &);
43 void handleParam(Argument *);
46 static char ID; // Pass identification, replacement for typeid
47 NVPTXLowerStructArgs() : FunctionPass(ID) {}
48 const char *getPassName() const override {
49 return "Copy structure (byval *) arguments to stack";
54 char NVPTXLowerStructArgs::ID = 1;
56 INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
57 "Lower structure arguments (NVPTX)", false, false)
59 void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
60 Function *Func = Arg->getParent();
61 Instruction *FirstInst = &(Func->getEntryBlock().front());
62 PointerType *PType = dyn_cast<PointerType>(Arg->getType());
64 assert(PType && "Expecting pointer type in handleParam");
66 Type *StructType = PType->getElementType();
67 AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst);
69 /* Set the alignment to alignment of the byval parameter. This is because,
70 * later load/stores assume that alignment, and we are going to replace
71 * the use of the byval parameter with this alloca instruction.
73 AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
75 Arg->replaceAllUsesWith(AllocA);
77 // Get the cvt.gen.to.param intrinsic
79 Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM),
80 Type::getInt8PtrTy(Func->getParent()->getContext(),
81 ADDRESS_SPACE_GENERIC)};
82 Function *CvtFunc = Intrinsic::getDeclaration(
83 Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param, CvtTypes);
85 Value *BitcastArgs[] = {
86 new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(),
87 ADDRESS_SPACE_GENERIC),
88 Arg->getName(), FirstInst)};
90 CallInst::Create(CvtFunc, BitcastArgs, "cvt_to_param", FirstInst);
92 BitCastInst *BitCast = new BitCastInst(
93 CallCVT, PointerType::get(StructType, ADDRESS_SPACE_PARAM),
94 Arg->getName(), FirstInst);
95 LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst);
96 new StoreInst(LI, AllocA, FirstInst);
99 // =============================================================================
100 // If the function had a struct ptr arg, say foo(%struct.x *byval %d), then
101 // add the following instructions to the first basic block :
103 // %temp = alloca %struct.x, align 8
104 // %tt1 = bitcast %struct.x * %d to i8 *
105 // %tt2 = llvm.nvvm.cvt.gen.to.param %tt2
106 // %tempd = bitcast i8 addrspace(101) * to %struct.x addrspace(101) *
107 // %tv = load %struct.x addrspace(101) * %tempd
108 // store %struct.x %tv, %struct.x * %temp, align 8
110 // The above code allocates some space in the stack and copies the incoming
111 // struct from param space to local space.
112 // Then replace all occurences of %d by %temp.
113 // =============================================================================
114 void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
115 for (Argument &Arg : F.args()) {
116 if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
122 // =============================================================================
123 // Main function for this pass.
124 // =============================================================================
125 bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
126 // Skip non-kernels. See the comments at the top of this file.
127 if (!isKernelFunction(F))
130 handleStructPtrArgs(F);
134 FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
135 return new NVPTXLowerStructArgs();