[NVPTX] convert pointers in byval kernel arguments to global
authorJingyue Wu <jingyue@google.com>
Fri, 31 Jul 2015 21:44:14 +0000 (21:44 +0000)
committerJingyue Wu <jingyue@google.com>
Fri, 31 Jul 2015 21:44:14 +0000 (21:44 +0000)
Summary:
For example, in

  struct S {
    int *x;
    int *y;
  };
  __global__ void foo(S s) {
    int *b = s.y;
    // use b
  }

"b" is guaranteed to point to global. NVPTX should emit ld.global/st.global for
accessing "b".

Reviewers: jholewinski

Subscribers: llvm-commits, jholewinski

Differential Revision: http://reviews.llvm.org/D11505

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@243790 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXLowerKernelArgs.cpp
test/CodeGen/NVPTX/lower-kernel-ptr-arg.ll

index b533f316d8a989594b18f9c93f3d4e76dd133584..daf44482934b2e7a4c3ea5fdc0213f2c500aa435 100644 (file)
 //      ...
 //    }
 //
+// 3. Convert pointers in a byval kernel parameter to pointers in the global
+//    address space. As #2, it allows NVPTX to emit more ld/st.global. E.g.,
+//
+//    struct S {
+//      int *x;
+//      int *y;
+//    };
+//    __global__ void foo(S s) {
+//      int *b = s.y;
+//      // use b
+//    }
+//
+//    "b" points to the global address space. In the IR level,
+//
+//    define void @foo({i32*, i32*}* byval %input) {
+//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
+//      %b = load i32*, i32** %b_ptr
+//      ; use %b
+//    }
+//
+//    becomes
+//
+//    define void @foo({i32*, i32*}* byval %input) {
+//      %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
+//      %b = load i32*, i32** %b_ptr
+//      %b_global = addrspacecast i32* %b to i32 addrspace(1)*
+//      %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
+//      ; use %b_generic
+//    }
+//
 // TODO: merge this pass with NVPTXFavorNonGenericAddrSpace so that other passes
 // don't cancel the addrspacecast pair this pass emits.
 //===----------------------------------------------------------------------===//
@@ -54,6 +84,7 @@
 #include "NVPTX.h"
 #include "NVPTXUtilities.h"
 #include "NVPTXTargetMachine.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -71,9 +102,12 @@ class NVPTXLowerKernelArgs : public FunctionPass {
   bool runOnFunction(Function &F) override;
 
   // handle byval parameters
-  void handleByValParam(Argument *);
-  // handle non-byval pointer parameters
-  void handlePointerParam(Argument *);
+  void handleByValParam(Argument *Arg);
+  // Knowing Ptr must point to the global address space, this function
+  // addrspacecasts Ptr to global and then back to generic. This allows
+  // NVPTXFavorNonGenericAddrSpace to fold the global-to-generic cast into
+  // loads/stores that appear later.
+  void markPointerAsGlobal(Value *Ptr);
 
 public:
   static char ID; // Pass identification, replacement for typeid
@@ -128,27 +162,34 @@ void NVPTXLowerKernelArgs::handleByValParam(Argument *Arg) {
   new StoreInst(LI, AllocA, FirstInst);
 }
 
-void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
-  assert(!Arg->hasByValAttr() &&
-         "byval params should be handled by handleByValParam");
-
-  // Do nothing if the argument already points to the global address space.
-  if (Arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
+void NVPTXLowerKernelArgs::markPointerAsGlobal(Value *Ptr) {
+  if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
     return;
 
-  Instruction *FirstInst = Arg->getParent()->getEntryBlock().begin();
-  Instruction *ArgInGlobal = new AddrSpaceCastInst(
-      Arg, PointerType::get(Arg->getType()->getPointerElementType(),
+  // Deciding where to emit the addrspacecast pair.
+  BasicBlock::iterator InsertPt;
+  if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
+    // Insert at the functon entry if Ptr is an argument.
+    InsertPt = Arg->getParent()->getEntryBlock().begin();
+  } else {
+    // Insert right after Ptr if Ptr is an instruction.
+    InsertPt = cast<Instruction>(Ptr);
+    ++InsertPt;
+    assert(InsertPt != InsertPt->getParent()->end() &&
+           "We don't call this function with Ptr being a terminator.");
+  }
+
+  Instruction *PtrInGlobal = new AddrSpaceCastInst(
+      Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
                             ADDRESS_SPACE_GLOBAL),
-      Arg->getName(), FirstInst);
-  Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
-                                              Arg->getName(), FirstInst);
-  // Replace with ArgInGeneric all uses of Args except ArgInGlobal.
-  Arg->replaceAllUsesWith(ArgInGeneric);
-  ArgInGlobal->setOperand(0, Arg);
+      Ptr->getName(), InsertPt);
+  Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
+                                              Ptr->getName(), InsertPt);
+  // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
+  Ptr->replaceAllUsesWith(PtrInGeneric);
+  PtrInGlobal->setOperand(0, Ptr);
 }
 
-
 // =============================================================================
 // Main function for this pass.
 // =============================================================================
@@ -157,12 +198,32 @@ bool NVPTXLowerKernelArgs::runOnFunction(Function &F) {
   if (!isKernelFunction(F))
     return false;
 
+  if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
+    // Mark pointers in byval structs as global.
+    for (auto &B : F) {
+      for (auto &I : B) {
+        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+          if (LI->getType()->isPointerTy()) {
+            Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
+                                            F.getParent()->getDataLayout());
+            if (Argument *Arg = dyn_cast<Argument>(UO)) {
+              if (Arg->hasByValAttr()) {
+                // LI is a load from a pointer within a byval kernel parameter.
+                markPointerAsGlobal(LI);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   for (Argument &Arg : F.args()) {
     if (Arg.getType()->isPointerTy()) {
       if (Arg.hasByValAttr())
         handleByValParam(&Arg);
       else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
-        handlePointerParam(&Arg);
+        markPointerAsGlobal(&Arg);
     }
   }
   return true;
index 0de72c4a1aed0bad7378374608ceab819d886cf3..2fffa3eeac15f4a749684ff4ea3e91a15c360156 100644 (file)
@@ -1,7 +1,7 @@
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s
 
 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
-target triple = "nvptx64-unknown-unknown"
+target triple = "nvptx64-nvidia-cuda"
 
 ; Verify that both %input and %output are converted to global pointers and then
 ; addrspacecast'ed back to the original type.
@@ -26,6 +26,22 @@ define void @kernel2(float addrspace(1)* %input, float addrspace(1)* %output) {
   ret void
 }
 
-!nvvm.annotations = !{!0, !1}
+%struct.S = type { i32*, i32* }
+
+define void @ptr_in_byval(%struct.S* byval %input, i32* %output) {
+; CHECK-LABEL: .visible .entry ptr_in_byval(
+; CHECK: cvta.to.global.u64
+; CHECK: cvta.to.global.u64
+  %b_ptr = getelementptr inbounds %struct.S, %struct.S* %input, i64 0, i32 1
+  %b = load i32*, i32** %b_ptr, align 4
+  %v = load i32, i32* %b, align 4
+; CHECK: ld.global.u32
+  store i32 %v, i32* %output, align 4
+; CHECK: st.global.u32
+  ret void
+}
+
+!nvvm.annotations = !{!0, !1, !2}
 !0 = !{void (float*, float*)* @kernel, !"kernel", i32 1}
 !1 = !{void (float addrspace(1)*, float addrspace(1)*)* @kernel2, !"kernel", i32 1}
+!2 = !{void (%struct.S*, i32*)* @ptr_in_byval, !"kernel", i32 1}