[NVPTX] Fix handling of vector arguments
authorJustin Holewinski <jholewinski@nvidia.com>
Sun, 24 Mar 2013 21:17:47 +0000 (21:17 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Sun, 24 Mar 2013 21:17:47 +0000 (21:17 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177847 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXAsmPrinter.cpp
lib/Target/NVPTX/NVPTXISelLowering.cpp
test/CodeGen/NVPTX/vector-args.ll [new file with mode: 0644]

index 0115e1f5d3a82337f59147d0094f6e75028a0a4d..c0e8670658a91d496b739a435ff38a6467b6135e 100644 (file)
@@ -1481,7 +1481,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
   O << "(\n";
 
   for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    const Type *Ty = I->getType();
+    Type *Ty = I->getType();
 
     if (!first)
       O << ",\n";
@@ -1504,6 +1504,22 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
     }
 
     if (PAL.hasAttribute(paramIndex+1, Attribute::ByVal) == false) {
+      if (Ty->isVectorTy()) {
+        // Just print .param .b8 .align <a> .param[size];
+        // <a> = PAL.getparamalignment
+        // size = typeallocsize of element type
+        unsigned align = PAL.getParamAlignment(paramIndex+1);
+        if (align == 0)
+          align = TD->getABITypeAlignment(Ty);
+
+        unsigned sz = TD->getTypeAllocSize(Ty);
+        O << "\t.param .align " << align
+          << " .b8 ";
+        printParamName(I, paramIndex, O);
+        O << "[" << sz << "]";
+
+        continue;
+      }
       // Just a scalar
       const PointerType *PTy = dyn_cast<PointerType>(Ty);
       if (isKernelFunc) {
index e9a9fbfd04fbb9f8c3263788ecca3079fe109765..987d34b2f37d044e9bef5dda9744eea0bc394b6e 100644 (file)
@@ -1058,15 +1058,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
     theArgs.push_back(I);
     argTypes.push_back(I->getType());
   }
-  assert(argTypes.size() == Ins.size() &&
-         "Ins types and function types did not match");
+  //assert(argTypes.size() == Ins.size() &&
+  //       "Ins types and function types did not match");
 
   int idx = 0;
-  for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
+  for (unsigned i=0, e=argTypes.size(); i!=e; ++i, ++idx) {
     Type *Ty = argTypes[i];
     EVT ObjectVT = getValueType(Ty);
-    assert(ObjectVT == Ins[i].VT &&
-           "Ins type did not match function type");
+    //assert(ObjectVT == Ins[i].VT &&
+    //       "Ins type did not match function type");
 
     // If the kernel argument is image*_t or sampler_t, convert it to
     // a i32 constant holding the parameter position. This can later
@@ -1081,7 +1081,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
 
     if (theArgs[i]->use_empty()) {
       // argument is dead
-      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+      if (ObjectVT.isVector()) {
+        EVT EltVT = ObjectVT.getVectorElementType();
+        unsigned NumElts = ObjectVT.getVectorNumElements();
+        for (unsigned vi = 0; vi < NumElts; ++vi) {
+          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
+        }
+      } else {
+        InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
+      }
       continue;
     }
 
@@ -1090,6 +1098,31 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
     // appear in the same order as their order of appearance
     // in the original function. "idx+1" holds that order.
     if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
+      if (ObjectVT.isVector()) {
+        unsigned NumElts = ObjectVT.getVectorNumElements();
+        EVT EltVT = ObjectVT.getVectorElementType();
+        unsigned Offset = 0;
+        for (unsigned vi = 0; vi < NumElts; ++vi) {
+          SDValue A = getParamSymbol(DAG, idx, getPointerTy());
+          SDValue B = DAG.getIntPtrConstant(Offset);
+          SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
+                                     //getParamSymbol(DAG, idx, EltVT),
+                                     //DAG.getConstant(Offset, getPointerTy()));
+                                     A, B);
+          Value *SrcValue = Constant::getNullValue(PointerType::get(
+                                            EltVT.getTypeForEVT(F->getContext()),
+                                            llvm::ADDRESS_SPACE_PARAM));
+          SDValue Ld = DAG.getLoad(EltVT, dl, Root, Addr,
+                                   MachinePointerInfo(SrcValue),
+                                   false, false, false,
+                                   TD->getABITypeAlignment(EltVT.getTypeForEVT(
+                                     F->getContext())));
+          Offset += EltVT.getStoreSizeInBits()/8;
+          InVals.push_back(Ld);
+        }
+        continue;
+      }
+
       // A plain scalar.
       if (isABI || isKernel) {
         // If ABI, load from the param symbol
diff --git a/test/CodeGen/NVPTX/vector-args.ll b/test/CodeGen/NVPTX/vector-args.ll
new file mode 100644 (file)
index 0000000..80deae4
--- /dev/null
@@ -0,0 +1,27 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+
+define float @foo(<2 x float> %a) {
+; CHECK: .func (.param .b32 func_retval0) foo
+; CHECK: .param .align 8 .b8 foo_param_0[8]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+  %t1 = fmul <2 x float> %a, %a
+  %t2 = extractelement <2 x float> %t1, i32 0
+  %t3 = extractelement <2 x float> %t1, i32 1
+  %t4 = fadd float %t2, %t3
+  ret float %t4
+}
+
+
+define float @bar(<4 x float> %a) {
+; CHECK: .func (.param .b32 func_retval0) bar
+; CHECK: .param .align 16 .b8 bar_param_0[16]
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+; CHECK: ld.param.f32 %f{{[0-9]+}}
+  %t1 = fmul <4 x float> %a, %a
+  %t2 = extractelement <4 x float> %t1, i32 0
+  %t3 = extractelement <4 x float> %t1, i32 1
+  %t4 = fadd float %t2, %t3
+  ret float %t4
+}