PTX: Prevent DCE from eliminating st.param calls, and unify the handling of
authorJustin Holewinski <justin.holewinski@gmail.com>
Thu, 23 Jun 2011 18:10:05 +0000 (18:10 +0000)
committerJustin Holewinski <justin.holewinski@gmail.com>
Thu, 23 Jun 2011 18:10:05 +0000 (18:10 +0000)
     st.param and ld.param

FIXME: Test cases still need to be updated

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@133733 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/PTX/PTXAsmPrinter.cpp
lib/Target/PTX/PTXISelDAGToDAG.cpp
lib/Target/PTX/PTXISelLowering.cpp
lib/Target/PTX/PTXISelLowering.h
lib/Target/PTX/PTXInstrInfo.td

index 0b055c24c7aebc6048f19d7d25eec910621c43fb..87b390383bf0bf70f59ea339f71e78c573e44112 100644 (file)
@@ -63,6 +63,8 @@ public:
                        const char *Modifier = 0);
   void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
                          const char *Modifier = 0);
+  void printReturnOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
+                          const char *Modifier = 0); 
   void printPredicateOperand(const MachineInstr *MI, raw_ostream &O);
 
   // autogen'd.
@@ -76,6 +78,7 @@ private:
 } // namespace
 
 static const char PARAM_PREFIX[] = "__param_";
+static const char RETURN_PREFIX[] = "__ret_";
 
 static const char *getRegisterTypeName(unsigned RegNo) {
 #define TEST_REGCLS(cls, clsstr)                \
@@ -298,6 +301,11 @@ void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
   OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
 }
 
+void PTXAsmPrinter::printReturnOperand(const MachineInstr *MI, int opNum,
+                                       raw_ostream &OS, const char *Modifier) {
+  OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
+}
+
 void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
   // Check to see if this is a special global used by LLVM, if so, emit it.
   if (EmitSpecialLLVMGlobal(gv))
@@ -421,6 +429,8 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
 
   std::string decl = isKernel ? ".entry" : ".func";
 
+  unsigned cnt = 0;
+
   if (!isKernel) {
     decl += " (";
 
@@ -430,10 +440,18 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
       if (i != b) {
         decl += ", ";
       }
-      decl += ".reg .";
-      decl += getRegisterTypeName(*i);
-      decl += " ";
-      decl += getRegisterName(*i);
+      if (ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
+        decl += ".param .b";
+        decl += utostr(*i);
+        decl += " ";
+        decl += RETURN_PREFIX;
+        decl += utostr(++cnt);
+      } else {
+        decl += ".reg .";
+        decl += getRegisterTypeName(*i);
+        decl += " ";
+        decl += getRegisterName(*i);
+      }
     }
     decl += ")";
   }
@@ -444,7 +462,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
 
   decl += " (";
 
-  unsigned cnt = 0;
+  cnt = 0;
 
   // Print parameters
   for (PTXMachineFunctionInfo::reg_iterator
index 1cae8f33bb6071e244ea1d1b8f2419708589f635..9adfa624b29ed4a77a8bf960deba63ba01c57f60 100644 (file)
@@ -42,9 +42,6 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
 #include "PTXGenDAGISel.inc"
 
   private:
-    SDNode *SelectREAD_PARAM(SDNode *Node);
-    //SDNode *SelectSTORE_PARAM(SDNode *Node);
-    
     // We need this only because we can't match intruction BRAdp
     // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td
     SDNode *SelectBRCOND(SDNode *Node);
@@ -69,10 +66,6 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM,
 
 SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
   switch (Node->getOpcode()) {
-    case PTXISD::READ_PARAM:
-      return SelectREAD_PARAM(Node);
-    // case PTXISD::STORE_PARAM:
-    //   return SelectSTORE_PARAM(Node);
     case ISD::BRCOND:
       return SelectBRCOND(Node);
     default:
@@ -80,68 +73,6 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
   }
 }
 
-SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
-  SDValue  index = Node->getOperand(1);
-  DebugLoc dl    = Node->getDebugLoc();
-  unsigned opcode;
-
-  if (index.getOpcode() != ISD::TargetConstant)
-    llvm_unreachable("READ_PARAM: index is not ISD::TargetConstant");
-
-  if (Node->getValueType(0) == MVT::i16) {
-    opcode = PTX::LDpiU16;
-  } else if (Node->getValueType(0) == MVT::i32) {
-    opcode = PTX::LDpiU32;
-  } else if (Node->getValueType(0) == MVT::i64) {
-    opcode = PTX::LDpiU64;
-  } else if (Node->getValueType(0) == MVT::f32) {
-    opcode = PTX::LDpiF32;
-  } else if (Node->getValueType(0) == MVT::f64) {
-    opcode = PTX::LDpiF64;
-  } else {
-    llvm_unreachable("Unknown parameter type for ld.param");
-  }
-
-  return PTXInstrInfo::
-    GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index);
-}
-
-// SDNode *PTXDAGToDAGISel::SelectSTORE_PARAM(SDNode *Node) {
-//   SDValue  Chain = Node->getOperand(0);
-//   SDValue  index = Node->getOperand(1);
-//   SDValue  value = Node->getOperand(2);
-//   DebugLoc dl    = Node->getDebugLoc();
-//   unsigned opcode;
-
-//   if (index.getOpcode() != ISD::TargetConstant)
-//     llvm_unreachable("STORE_PARAM: index is not ISD::TargetConstant");
-
-//   if (value->getValueType(0) == MVT::i16) {
-//     opcode = PTX::STpiU16;
-//   } else if (value->getValueType(0) == MVT::i32) {
-//     opcode = PTX::STpiU32;
-//   } else if (value->getValueType(0) == MVT::i64) {
-//     opcode = PTX::STpiU64;
-//   } else if (value->getValueType(0) == MVT::f32) {
-//     opcode = PTX::STpiF32;
-//   } else if (value->getValueType(0) == MVT::f64) {
-//     opcode = PTX::STpiF64;
-//   } else {
-//     llvm_unreachable("Unknown parameter type for st.param");
-//   }
-
-//   SDVTList VTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
-//   SDValue PredReg = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
-//   SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
-//   SDValue Ops[] = { Chain, index, value, PredReg, PredOp };
-//   //SDNode *RetNode = PTXInstrInfo::
-//   //  GetPTXMachineNode(CurDAG, opcode, dl, VTs, index, value);
-//   SDNode *RetNode = CurDAG->getMachineNode(opcode, dl, VTs, Ops, array_lengthof(Ops));
-//   DEBUG(dbgs() << "SelectSTORE_PARAM: Selected: ");
-//   RetNode->dumpr(CurDAG);
-//   return RetNode;
-// }
-
 SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
   assert(Node->getNumOperands() >= 3);
 
index 782d916595715b29cf43ab4f7b35e9343e0bdf31..34660bf8986646d9b4b912620c03fd4aa9d6c0d9 100644 (file)
@@ -105,8 +105,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
       llvm_unreachable("Unknown opcode");
     case PTXISD::COPY_ADDRESS:
       return "PTXISD::COPY_ADDRESS";
-    case PTXISD::READ_PARAM:
-      return "PTXISD::READ_PARAM";
+    case PTXISD::LOAD_PARAM:
+      return "PTXISD::LOAD_PARAM";
     case PTXISD::STORE_PARAM:
       return "PTXISD::STORE_PARAM";
     case PTXISD::EXIT:
@@ -215,13 +215,13 @@ SDValue PTXTargetLowering::
   // SM < 2.0               ->  Use registers for arguments
   
   if (MFI->isKernel() || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
-    // We just need to emit the proper READ_PARAM ISDs
+    // We just need to emit the proper LOAD_PARAM ISDs
     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
 
       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
              "Kernels cannot take pred operands");
 
-      SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain,
+      SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
                                      DAG.getTargetConstant(i, MVT::i32));
       InVals.push_back(ArgValue);
 
index e33c0bd9540f7d552aa365c47d58bbeded30b5eb..43185416e1fc7fca8e4a5ed5c85e986b72c4f58d 100644 (file)
@@ -24,7 +24,7 @@ class PTXTargetMachine;
 namespace PTXISD {
   enum NodeType {
     FIRST_NUMBER = ISD::BUILTIN_OP_END,
-    READ_PARAM,
+    LOAD_PARAM,
     STORE_PARAM,
     EXIT,
     RET,
index b5597d4addc0ed38b5f438204da1d97c96b25542..1c18c4aa33e830f3b12129bc9767b6235d4e96bd 100644 (file)
@@ -163,6 +163,10 @@ def MEMpi : Operand<i32> {
   let PrintMethod = "printParamOperand";
   let MIOperandInfo = (ops i32imm);
 }
+def MEMret : Operand<i32> {
+  let PrintMethod = "printReturnOperand";
+  let MIOperandInfo = (ops i32imm);
+}
 
 // Branch & call targets have OtherVT type.
 def brtarget   : Operand<OtherVT>;
@@ -185,6 +189,10 @@ def PTXret
 def PTXcopyaddress
   : SDNode<"PTXISD::COPY_ADDRESS", SDTypeProfile<1, 1, []>, []>;
 
+// Load/store .param space
+def PTXloadparam
+  : SDNode<"PTXISD::LOAD_PARAM", SDTypeProfile<1, 1, [SDTCisVT<1, i32>]>,
+           [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
 def PTXstoreparam
   : SDNode<"PTXISD::STORE_PARAM", SDTypeProfile<0, 2, [SDTCisVT<0, i32>]>,
            [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
@@ -821,34 +829,48 @@ defm LDc : PTX_LD_ALL<"ld.const",  load_constant>;
 defm LDl : PTX_LD_ALL<"ld.local",  load_local>;
 defm LDs : PTX_LD_ALL<"ld.shared", load_shared>;
 
-// This is a special instruction that is manually inserted for parameters
-def LDpiU16 : InstPTX<(outs RegI16:$d), (ins MEMpi:$a),
-                      "ld.param.u16\t$d, [$a]", []>;
-def LDpiU32 : InstPTX<(outs RegI32:$d), (ins MEMpi:$a),
-                      "ld.param.u32\t$d, [$a]", []>;
-def LDpiU64 : InstPTX<(outs RegI64:$d), (ins MEMpi:$a),
-                      "ld.param.u64\t$d, [$a]", []>;
-def LDpiF32 : InstPTX<(outs RegF32:$d), (ins MEMpi:$a),
-                      "ld.param.f32\t$d, [$a]", []>;
-def LDpiF64 : InstPTX<(outs RegF64:$d), (ins MEMpi:$a),
-                      "ld.param.f64\t$d, [$a]", []>;
-
-// def STpiPred : InstPTX<(outs), (ins i1imm:$d, RegPred:$a),
-//                        "st.param.pred\t[$d], $a",
-//                        [(PTXstoreparam imm:$d, RegPred:$a)]>;
-// def STpiU16 : InstPTX<(outs), (ins i16imm:$d, RegI16:$a),
-//                       "st.param.u16\t[$d], $a",
-//                       [(PTXstoreparam imm:$d, RegI16:$a)]>;
-def STpiU32 : InstPTX<(outs), (ins i32imm:$d, RegI32:$a),
-                      "st.param.u32\t[$d], $a",
-                      [(PTXstoreparam timm:$d, RegI32:$a)]>;
-// def STpiU64 : InstPTX<(outs), (ins i64imm:$d, RegI64:$a),
-//                       "st.param.u64\t[$d], $a",
-//                       [(PTXstoreparam imm:$d, RegI64:$a)]>;
-// def STpiF32 : InstPTX<(outs), (ins MEMpi:$d, RegF32:$a),
-//                       "st.param.f32\t[$d], $a", []>;
-// def STpiF64 : InstPTX<(outs), (ins MEMpi:$d, RegF64:$a),
-//                       "st.param.f64\t[$d], $a", []>;
+// These instructions are used to load/store from the .param space for
+// device and kernel parameters
+
+let hasSideEffects = 1 in {
+  def LDpiPred : InstPTX<(outs RegPred:$d), (ins MEMpi:$a),
+                         "ld.param.pred\t$d, [$a]",
+                         [(set RegPred:$d, (PTXloadparam timm:$a))]>;
+  def LDpiU16  : InstPTX<(outs RegI16:$d), (ins MEMpi:$a),
+                         "ld.param.u16\t$d, [$a]",
+                         [(set RegI16:$d, (PTXloadparam timm:$a))]>;
+  def LDpiU32  : InstPTX<(outs RegI32:$d), (ins MEMpi:$a),
+                         "ld.param.u32\t$d, [$a]",
+                         [(set RegI32:$d, (PTXloadparam timm:$a))]>;
+  def LDpiU64  : InstPTX<(outs RegI64:$d), (ins MEMpi:$a),
+                         "ld.param.u64\t$d, [$a]",
+                         [(set RegI64:$d, (PTXloadparam timm:$a))]>;
+  def LDpiF32  : InstPTX<(outs RegF32:$d), (ins MEMpi:$a),
+                         "ld.param.f32\t$d, [$a]",
+                         [(set RegF32:$d, (PTXloadparam timm:$a))]>;
+  def LDpiF64  : InstPTX<(outs RegF64:$d), (ins MEMpi:$a),
+                         "ld.param.f64\t$d, [$a]",
+                         [(set RegF64:$d, (PTXloadparam timm:$a))]>;
+
+  def STpiPred : InstPTX<(outs), (ins MEMret:$d, RegPred:$a),
+                         "st.param.pred\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegPred:$a)]>;
+  def STpiU16  : InstPTX<(outs), (ins MEMret:$d, RegI16:$a),
+                         "st.param.u16\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegI16:$a)]>;
+  def STpiU32  : InstPTX<(outs), (ins MEMret:$d, RegI32:$a),
+                         "st.param.u32\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegI32:$a)]>;
+  def STpiU64  : InstPTX<(outs), (ins MEMret:$d, RegI64:$a),
+                         "st.param.u64\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegI64:$a)]>;
+  def STpiF32  : InstPTX<(outs), (ins MEMret:$d, RegF32:$a),
+                         "st.param.f32\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegF32:$a)]>;
+  def STpiF64  : InstPTX<(outs), (ins MEMret:$d, RegF64:$a),
+                         "st.param.f64\t[$d], $a",
+                         [(PTXstoreparam timm:$d, RegF64:$a)]>;
+}
 
 // Stores
 defm STg : PTX_ST_ALL<"st.global", store_global>;