From: Justin Holewinski Date: Thu, 23 Jun 2011 18:10:03 +0000 (+0000) Subject: PTX: Use .param space for parameters in device functions for SM >= 2.0 X-Git-Url: http://demsky.eecs.uci.edu/git/?a=commitdiff_plain;h=67a918486132309f224d152188747ca5e7f224ca;p=oota-llvm.git PTX: Use .param space for parameters in device functions for SM >= 2.0 FIXME: DCE is eliminating the final st.param.x calls, figure out why git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@133732 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index b1f7c1e4b2b..0b055c24c7a 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -417,6 +417,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { const PTXMachineFunctionInfo *MFI = MF->getInfo(); const bool isKernel = MFI->isKernel(); + const PTXSubtarget& ST = TM.getSubtarget(); std::string decl = isKernel ? ".entry" : ".func"; @@ -452,7 +453,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { if (i != b) { decl += ", "; } - if (isKernel) { + if (isKernel || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) { decl += ".param .b"; decl += utostr(*i); decl += " "; diff --git a/lib/Target/PTX/PTXISelDAGToDAG.cpp b/lib/Target/PTX/PTXISelDAGToDAG.cpp index b3c85da7b44..1cae8f33bb6 100644 --- a/lib/Target/PTX/PTXISelDAGToDAG.cpp +++ b/lib/Target/PTX/PTXISelDAGToDAG.cpp @@ -15,6 +15,7 @@ #include "PTXTargetMachine.h" #include "llvm/CodeGen/SelectionDAGISel.h" #include "llvm/DerivedTypes.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -42,7 +43,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel { private: SDNode *SelectREAD_PARAM(SDNode *Node); - + //SDNode *SelectSTORE_PARAM(SDNode *Node); + // We need this only because we can't match intruction BRAdp // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td SDNode *SelectBRCOND(SDNode *Node); @@ -69,6 +71,8 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) { switch (Node->getOpcode()) { case PTXISD::READ_PARAM: return SelectREAD_PARAM(Node); + // case PTXISD::STORE_PARAM: + // return SelectSTORE_PARAM(Node); case ISD::BRCOND: return SelectBRCOND(Node); default: @@ -86,20 +90,15 @@ SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) { if (Node->getValueType(0) == MVT::i16) { opcode = PTX::LDpiU16; - } - else if (Node->getValueType(0) == MVT::i32) { + } else if (Node->getValueType(0) == MVT::i32) { opcode = PTX::LDpiU32; - } - else if (Node->getValueType(0) == MVT::i64) { + } else if (Node->getValueType(0) == MVT::i64) { opcode = PTX::LDpiU64; - } - else if (Node->getValueType(0) == MVT::f32) { + } else if (Node->getValueType(0) == MVT::f32) { opcode = PTX::LDpiF32; - } - else if (Node->getValueType(0) == MVT::f64) { + } else if (Node->getValueType(0) == MVT::f64) { opcode = PTX::LDpiF64; - } - else { + } else { llvm_unreachable("Unknown parameter type for ld.param"); } @@ -107,6 +106,42 @@ SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) { GetPTXMachineNode(CurDAG, opcode, dl, Node->getValueType(0), index); } +// SDNode *PTXDAGToDAGISel::SelectSTORE_PARAM(SDNode *Node) { +// SDValue Chain = Node->getOperand(0); +// SDValue index = Node->getOperand(1); +// SDValue value = Node->getOperand(2); +// DebugLoc dl = Node->getDebugLoc(); +// unsigned opcode; + +// if (index.getOpcode() != ISD::TargetConstant) +// llvm_unreachable("STORE_PARAM: index is not ISD::TargetConstant"); + +// if (value->getValueType(0) == MVT::i16) { +// opcode = PTX::STpiU16; +// } else if (value->getValueType(0) == MVT::i32) { +// opcode = PTX::STpiU32; +// } else if (value->getValueType(0) == MVT::i64) { +// opcode = PTX::STpiU64; +// } else if (value->getValueType(0) == MVT::f32) { +// opcode = PTX::STpiF32; +// } else if (value->getValueType(0) == MVT::f64) { +// opcode = PTX::STpiF64; +// } else { +// llvm_unreachable("Unknown parameter type for st.param"); +// } + +// SDVTList VTs = CurDAG->getVTList(MVT::Other, MVT::Glue); +// SDValue PredReg = CurDAG->getRegister(PTX::NoRegister, MVT::i1); +// SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32); +// SDValue Ops[] = { Chain, index, value, PredReg, PredOp }; +// //SDNode *RetNode = PTXInstrInfo:: +// // GetPTXMachineNode(CurDAG, opcode, dl, VTs, index, value); +// SDNode *RetNode = CurDAG->getMachineNode(opcode, dl, VTs, Ops, array_lengthof(Ops)); +// DEBUG(dbgs() << "SelectSTORE_PARAM: Selected: "); +// RetNode->dumpr(CurDAG); +// return RetNode; +// } + SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) { assert(Node->getNumOperands() >= 3); diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index c3cdabad51d..782d9165957 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -15,6 +15,7 @@ #include "PTXISelLowering.h" #include "PTXMachineFunctionInfo.h" #include "PTXRegisterInfo.h" +#include "PTXSubtarget.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" @@ -106,6 +107,8 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "PTXISD::COPY_ADDRESS"; case PTXISD::READ_PARAM: return "PTXISD::READ_PARAM"; + case PTXISD::STORE_PARAM: + return "PTXISD::STORE_PARAM"; case PTXISD::EXIT: return "PTXISD::EXIT"; case PTXISD::RET: @@ -192,6 +195,7 @@ SDValue PTXTargetLowering:: if (isVarArg) llvm_unreachable("PTX does not support varargs"); MachineFunction &MF = DAG.getMachineFunction(); + const PTXSubtarget& ST = getTargetMachine().getSubtarget(); PTXMachineFunctionInfo *MFI = MF.getInfo(); switch (CallConv) { @@ -206,11 +210,16 @@ SDValue PTXTargetLowering:: break; } - if (MFI->isKernel()) { - // For kernel functions, we just need to emit the proper READ_PARAM ISDs + // We do one of two things here: + // IsKernel || SM >= 2.0 -> Use param space for arguments + // SM < 2.0 -> Use registers for arguments + + if (MFI->isKernel() || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) { + // We just need to emit the proper READ_PARAM ISDs for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - assert(Ins[i].VT != MVT::i1 && "Kernels cannot take pred operands"); + assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) && + "Kernels cannot take pred operands"); SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain, DAG.getTargetConstant(i, MVT::i32)); @@ -299,31 +308,49 @@ SDValue PTXTargetLowering:: MachineFunction& MF = DAG.getMachineFunction(); PTXMachineFunctionInfo *MFI = MF.getInfo(); - SmallVector RVLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), RVLocs, *DAG.getContext()); + const PTXSubtarget& ST = getTargetMachine().getSubtarget(); SDValue Flag; - CCInfo.AnalyzeReturn(Outs, RetCC_PTX); + if (ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) { + // For SM 2.0+, we return arguments in the param space + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue ParamIndex = DAG.getTargetConstant(i, MVT::i32); + SDValue Ops[] = { Chain, ParamIndex, OutVals[i], Flag }; + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, VTs, Ops, + Flag.getNode() ? 4 : 3); + Flag = Chain.getValue(1); + // Instead of storing a physical register in our argument list, we just + // store the total size of the parameter, in bits. The ASM printer + // knows how to process this. + MFI->addRetReg(Outs[i].VT.getStoreSizeInBits()); + } + } else { + // For SM < 2.0, we return arguments in registers + SmallVector RVLocs; + CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), + getTargetMachine(), RVLocs, *DAG.getContext()); - for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { + CCInfo.AnalyzeReturn(Outs, RetCC_PTX); - CCValAssign& VA = RVLocs[i]; + for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { + CCValAssign& VA = RVLocs[i]; - assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); + assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); - unsigned Reg = VA.getLocReg(); + unsigned Reg = VA.getLocReg(); - DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); + DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); - Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); + Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); - // Guarantee that all emitted copies are stuck together, - // avoiding something bad - Flag = Chain.getValue(1); + // Guarantee that all emitted copies are stuck together, + // avoiding something bad + Flag = Chain.getValue(1); - MFI->addRetReg(Reg); + MFI->addRetReg(Reg); + } } if (Flag.getNode() == 0) { diff --git a/lib/Target/PTX/PTXISelLowering.h b/lib/Target/PTX/PTXISelLowering.h index ead17edc01a..e33c0bd9540 100644 --- a/lib/Target/PTX/PTXISelLowering.h +++ b/lib/Target/PTX/PTXISelLowering.h @@ -25,11 +25,12 @@ namespace PTXISD { enum NodeType { FIRST_NUMBER = ISD::BUILTIN_OP_END, READ_PARAM, + STORE_PARAM, EXIT, RET, COPY_ADDRESS }; -} // namespace PTXISD +} // namespace PTXISD class PTXTargetLowering : public TargetLowering { public: diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td index cc7494412bc..b5597d4addc 100644 --- a/lib/Target/PTX/PTXInstrInfo.td +++ b/lib/Target/PTX/PTXInstrInfo.td @@ -180,10 +180,15 @@ def PTXsra : SDNode<"ISD::SRA", SDTIntBinOp>; def PTXexit : SDNode<"PTXISD::EXIT", SDTNone, [SDNPHasChain]>; def PTXret - : SDNode<"PTXISD::RET", SDTNone, [SDNPHasChain]>; + : SDNode<"PTXISD::RET", SDTNone, + [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>; def PTXcopyaddress : SDNode<"PTXISD::COPY_ADDRESS", SDTypeProfile<1, 1, []>, []>; +def PTXstoreparam + : SDNode<"PTXISD::STORE_PARAM", SDTypeProfile<0, 2, [SDTCisVT<0, i32>]>, + [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>; + //===----------------------------------------------------------------------===// // Instruction Class Templates //===----------------------------------------------------------------------===// @@ -816,7 +821,7 @@ defm LDc : PTX_LD_ALL<"ld.const", load_constant>; defm LDl : PTX_LD_ALL<"ld.local", load_local>; defm LDs : PTX_LD_ALL<"ld.shared", load_shared>; -// This is a special instruction that is manually inserted for kernel parameters +// This is a special instruction that is manually inserted for parameters def LDpiU16 : InstPTX<(outs RegI16:$d), (ins MEMpi:$a), "ld.param.u16\t$d, [$a]", []>; def LDpiU32 : InstPTX<(outs RegI32:$d), (ins MEMpi:$a), @@ -828,6 +833,23 @@ def LDpiF32 : InstPTX<(outs RegF32:$d), (ins MEMpi:$a), def LDpiF64 : InstPTX<(outs RegF64:$d), (ins MEMpi:$a), "ld.param.f64\t$d, [$a]", []>; +// def STpiPred : InstPTX<(outs), (ins i1imm:$d, RegPred:$a), +// "st.param.pred\t[$d], $a", +// [(PTXstoreparam imm:$d, RegPred:$a)]>; +// def STpiU16 : InstPTX<(outs), (ins i16imm:$d, RegI16:$a), +// "st.param.u16\t[$d], $a", +// [(PTXstoreparam imm:$d, RegI16:$a)]>; +def STpiU32 : InstPTX<(outs), (ins i32imm:$d, RegI32:$a), + "st.param.u32\t[$d], $a", + [(PTXstoreparam timm:$d, RegI32:$a)]>; +// def STpiU64 : InstPTX<(outs), (ins i64imm:$d, RegI64:$a), +// "st.param.u64\t[$d], $a", +// [(PTXstoreparam imm:$d, RegI64:$a)]>; +// def STpiF32 : InstPTX<(outs), (ins MEMpi:$d, RegF32:$a), +// "st.param.f32\t[$d], $a", []>; +// def STpiF64 : InstPTX<(outs), (ins MEMpi:$d, RegF64:$a), +// "st.param.f64\t[$d], $a", []>; + // Stores defm STg : PTX_ST_ALL<"st.global", store_global>; defm STl : PTX_ST_ALL<"st.local", store_local>; diff --git a/lib/Target/PTX/PTXSubtarget.h b/lib/Target/PTX/PTXSubtarget.h index c8f8c3b00d0..2ebe6cfdc83 100644 --- a/lib/Target/PTX/PTXSubtarget.h +++ b/lib/Target/PTX/PTXSubtarget.h @@ -18,7 +18,7 @@ namespace llvm { class PTXSubtarget : public TargetSubtarget { - private: + public: /** * Enumeration of Shader Models supported by the back-end. @@ -41,6 +41,8 @@ namespace llvm { PTX_VERSION_2_3 /*< PTX Version 2.3 */ }; + private: + /// Shader Model supported on the target GPU. PTXShaderModelEnum PTXShaderModel; @@ -58,8 +60,10 @@ namespace llvm { bool Is64Bit; public: + PTXSubtarget(const std::string &TT, const std::string &FS, bool is64Bit); + // Target architecture accessors std::string getTargetString() const; std::string getPTXVersionString() const; @@ -80,6 +84,9 @@ namespace llvm { bool supportsPTX23() const { return PTXVersion >= PTX_VERSION_2_3; } + PTXShaderModelEnum getShaderModel() const { return PTXShaderModel; } + + std::string ParseSubtargetFeatures(const std::string &FS, const std::string &CPU); }; // class PTXSubtarget