1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file implements the PTXTargetLowering class.
12 //===----------------------------------------------------------------------===//
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Function.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SelectionDAG.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
31 //===----------------------------------------------------------------------===//
32 // TargetLowering Implementation
33 //===----------------------------------------------------------------------===//
// Constructor: configures how SelectionDAG nodes are legalized for PTX.
// It registers one register class per supported scalar type (i1, i16, i32,
// i64, f32, f64), sets the boolean-content and minimum-function-alignment
// properties, marks a set of operations as Expand / Legal / Custom, and
// finally calls computeRegisterProperties().
// The base class is handed a freshly allocated TargetLoweringObjectFileELF.
35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
36 : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
37 // Set up the register classes.
38 addRegisterClass(MVT::i1, PTX::RegPredRegisterClass);
39 addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
40 addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
41 addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
42 addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
43 addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
// Scalar and vector booleans are both declared as zero-or-one valued.
45 setBooleanContents(ZeroOrOneBooleanContent);
46 setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
47 setMinFunctionAlignment(2);
49 ////////////////////////////////////
50 /////////// Expansion //////////////
51 ////////////////////////////////////
53 // (any/zero/sign) extload => load + (any/zero/sign) extend
55 setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
56 setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
57 setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
59 // f32 extload => load + fextend
61 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
63 // f64 truncstore => trunc + store
65 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
67 // sign_extend_inreg => sign_extend
69 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
// BR_CC (on MVT::Other) and SELECT_CC (on MVT::Other, f32, f64) are also
// marked Expand.
73 setOperationAction(ISD::BR_CC, MVT::Other, Expand);
77 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
78 setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
79 setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
81 ////////////////////////////////////
82 //////////// Legal /////////////////
83 ////////////////////////////////////
// Floating-point constants of both widths are declared Legal.
85 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
86 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
88 ////////////////////////////////////
89 //////////// Custom ////////////////
90 ////////////////////////////////////
92 // customize setcc to use bitwise logic if possible
94 setOperationAction(ISD::SETCC, MVT::i1, Custom);
96 // customize translation of memory addresses
98 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
99 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
101 // Compute derived properties from the register classes
102 computeRegisterProperties();
// Takes a value type VT and returns an EVT.
// NOTE(review): the function body is not visible in this listing. LowerSETCC
// asserts that SETCC results are MVT::i1, so this presumably returns MVT::i1
// regardless of VT -- confirm against the full file.
105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
// Entry point for operations marked Custom: switches on the opcode of Op and
// forwards to the matching Lower* helper. The visible targets are LowerSETCC
// and LowerGlobalAddress (for ISD::GlobalAddress); the remaining path ends in
// llvm_unreachable("Unimplemented operand").
// NOTE(review): the case labels guarding the llvm_unreachable and the
// LowerSETCC call are not visible in this listing -- presumably `default:`
// and `case ISD::SETCC:`; confirm against the full file.
109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
110 switch (Op.getOpcode()) {
112 llvm_unreachable("Unimplemented operand");
114 return LowerSETCC(Op, DAG);
115 case ISD::GlobalAddress:
116 return LowerGlobalAddress(Op, DAG);
// Maps a PTXISD opcode to its printable name (a string literal of the form
// "PTXISD::<NAME>"). An opcode that matches no case reaches
// llvm_unreachable("Unknown opcode").
// NOTE(review): the switch header and the case labels for EXIT, RET and CALL
// are not visible in this listing; only their return statements are.
120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
123 llvm_unreachable("Unknown opcode");
124 case PTXISD::COPY_ADDRESS:
125 return "PTXISD::COPY_ADDRESS";
126 case PTXISD::LOAD_PARAM:
127 return "PTXISD::LOAD_PARAM";
128 case PTXISD::STORE_PARAM:
129 return "PTXISD::STORE_PARAM";
130 case PTXISD::READ_PARAM:
131 return "PTXISD::READ_PARAM";
132 case PTXISD::WRITE_PARAM:
133 return "PTXISD::WRITE_PARAM";
135 return "PTXISD::EXIT";
137 return "PTXISD::RET";
139 return "PTXISD::CALL";
143 //===----------------------------------------------------------------------===//
144 // Custom Lower Operation
145 //===----------------------------------------------------------------------===//
147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
148 assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
149 SDValue Op0 = Op.getOperand(0);
150 SDValue Op1 = Op.getOperand(1);
151 SDValue Op2 = Op.getOperand(2);
152 DebugLoc dl = Op.getDebugLoc();
153 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
155 // Look for X == 0, X == 1, X != 0, or X != 1
156 // We can simplify these to bitwise logic
158 if (Op1.getOpcode() == ISD::Constant &&
159 (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
160 cast<ConstantSDNode>(Op1)->isNullValue()) &&
161 (CC == ISD::SETEQ || CC == ISD::SETNE)) {
163 return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
166 return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
// Custom lowering for ISD::GlobalAddress nodes.
// Extracts the GlobalValue from Op (Op must be a GlobalAddressSDNode, as the
// cast<> enforces), builds a target global-address node of pointer type for
// it, and starts building a PTXISD::COPY_ADDRESS node.
// NOTE(review): the operands of the COPY_ADDRESS node and the return
// statement are not visible in this listing.
169 SDValue PTXTargetLowering::
170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
171 EVT PtrVT = getPointerTy();
172 DebugLoc dl = Op.getDebugLoc();
173 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
175 assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
177 SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
178 SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
186 //===----------------------------------------------------------------------===//
187 // Calling Convention Implementation
188 //===----------------------------------------------------------------------===//
// Lowers the incoming formal arguments of a function.
//
// For every entry in Ins one SDValue is appended to InVals. Variadic
// functions are rejected. The calling convention selects kernel vs. device
// mode (recorded via MFI->setKernel); any other convention reaches
// llvm_unreachable("Unsupported calling convention").
//
// Two strategies are used, as the in-body comment describes:
//   - kernel, or ST.useParamSpaceForDeviceArgs(): each argument gets a
//     parameter from the PTXParamManager and is read with PTXISD::LOAD_PARAM;
//   - otherwise: each argument gets a fresh virtual register of the class
//     returned by getRegClassFor() and is read with PTXISD::READ_PARAM.
//
// NOTE(review): several parameter declarations (isVarArg, dl, DAG), the
// switch header, the trailing operands of the getNode/getTargetExternalSymbol
// calls and the return statement are not visible in this listing.
190 SDValue PTXTargetLowering::
191 LowerFormalArguments(SDValue Chain,
192 CallingConv::ID CallConv,
194 const SmallVectorImpl<ISD::InputArg> &Ins,
197 SmallVectorImpl<SDValue> &InVals) const {
198 if (isVarArg) llvm_unreachable("PTX does not support varargs");
200 MachineFunction &MF = DAG.getMachineFunction();
201 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
202 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
203 PTXParamManager &PM = MFI->getParamManager();
// Record whether this function is a kernel or a device function.
207 llvm_unreachable("Unsupported calling convention");
209 case CallingConv::PTX_Kernel:
210 MFI->setKernel(true);
212 case CallingConv::PTX_Device:
213 MFI->setKernel(false);
217 // We do one of two things here:
218 // IsKernel || SM >= 2.0 -> Use param space for arguments
219 // SM < 2.0 -> Use registers for arguments
220 if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
221 // We just need to emit the proper LOAD_PARAM ISDs
222 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
223 assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
224 "Kernels cannot take pred operands");
// The parameter is sized by the store size of the argument type and is
// referenced through an external-symbol node carrying its generated name.
226 unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
227 unsigned Param = PM.addArgumentParam(ParamSize);
228 const std::string &ParamName = PM.getParamName(Param);
229 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
231 SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
233 InVals.push_back(ArgValue);
// Register path: one virtual register and one READ_PARAM per argument.
237 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
238 EVT RegVT = Ins[i].VT;
239 TargetRegisterClass* TRC = getRegClassFor(RegVT);
241 // Use a unique index in the instruction to prevent instruction folding.
242 // Yes, this is a hack.
243 SDValue Index = DAG.getTargetConstant(i, MVT::i32);
244 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
245 SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
248 InVals.push_back(ArgValue);
// Lowers a function return.
//
// Variadic functions are rejected. For CallingConv::PTX_Kernel the function
// must return nothing and a PTXISD::EXIT node is produced immediately. For
// CallingConv::PTX_Device at most one value may be returned; it is written
// either to a return parameter via PTXISD::STORE_PARAM (when
// ST.useParamSpaceForDeviceArgs()) or copied into a fresh virtual register
// and published with PTXISD::WRITE_PARAM. The result is a PTXISD::RET node,
// with Flag attached as an extra operand when it is set.
//
// NOTE(review): the isVarArg/dl parameters, the switch header, the
// declaration of Flag and several closing braces are not visible in this
// listing.
257 SDValue PTXTargetLowering::
258 LowerReturn(SDValue Chain,
259 CallingConv::ID CallConv,
261 const SmallVectorImpl<ISD::OutputArg> &Outs,
262 const SmallVectorImpl<SDValue> &OutVals,
264 SelectionDAG &DAG) const {
265 if (isVarArg) llvm_unreachable("PTX does not support varargs");
269 llvm_unreachable("Unsupported calling convention.");
270 case CallingConv::PTX_Kernel:
271 assert(Outs.size() == 0 && "Kernel must return void.");
272 return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
273 case CallingConv::PTX_Device:
274 assert(Outs.size() <= 1 && "Can at most return one value.");
278 MachineFunction& MF = DAG.getMachineFunction();
279 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
280 PTXParamManager &PM = MFI->getParamManager();
283 const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
285 if (ST.useParamSpaceForDeviceArgs()) {
// NOTE(review): this repeats the "at most one value" check asserted above
// for PTX_Device.
286 assert(Outs.size() < 2 && "Device functions can return at most one value");
288 if (Outs.size() == 1) {
// NOTE(review): the size here comes from getSizeInBits(), whereas
// LowerFormalArguments sizes its parameters with getStoreSizeInBits();
// the two differ for types such as i1 -- confirm which is intended.
289 unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
290 unsigned Param = PM.addReturnParam(ParamSize);
291 const std::string &ParamName = PM.getParamName(Param);
292 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
294 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
295 ParamValue, OutVals[0]);
// Register path: copy each returned value into a new virtual register of the
// class matching its type, then chain a WRITE_PARAM on that register.
298 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
299 EVT RegVT = Outs[i].VT;
300 TargetRegisterClass* TRC = 0;
302 // Determine which register class we need
// NOTE(review): this chain mirrors the addRegisterClass() table in the
// constructor; LowerFormalArguments obtains the class via getRegClassFor()
// instead.
303 if (RegVT == MVT::i1) {
304 TRC = PTX::RegPredRegisterClass;
306 else if (RegVT == MVT::i16) {
307 TRC = PTX::RegI16RegisterClass;
309 else if (RegVT == MVT::i32) {
310 TRC = PTX::RegI32RegisterClass;
312 else if (RegVT == MVT::i64) {
313 TRC = PTX::RegI64RegisterClass;
315 else if (RegVT == MVT::f32) {
316 TRC = PTX::RegF32RegisterClass;
318 else if (RegVT == MVT::f64) {
319 TRC = PTX::RegF64RegisterClass;
322 llvm_unreachable("Unknown parameter type");
325 unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
327 SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
328 SDValue OutReg = DAG.getRegister(Reg, RegVT);
330 Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
// Emit RET, passing Flag along only when it refers to a node.
336 if (Flag.getNode() == 0) {
337 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
340 return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
// Lowers an outgoing call into a PTXISD::CALL node.
//
// Steps visible here:
//   1. Assert that the subtarget reports callsAreHandled().
//   2. Resolve Callee to a target global address. Callee is cast<> to
//      GlobalAddressSDNode, so it must be a global-address node, and the
//      callee is asserted to use CallingConv::PTX_Device.
//   3. For each outgoing value, allocate a local parameter, store the value
//      into it with PTXISD::STORE_PARAM (threaded through Chain), and record
//      the parameter symbol as a call operand.
//   4. For each expected result, allocate a local parameter and record its
//      symbol both as a call operand and in InParams.
//   5. Build the CALL node from Ops, then emit one PTXISD::LOAD_PARAM per
//      result and append it to InVals.
//
// NOTE(review): the return type line, some parameters, the assignment of
// Ops[0], the trailing operands of several calls and the return statement
// are not visible in this listing.
345 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
346 CallingConv::ID CallConv, bool isVarArg,
348 const SmallVectorImpl<ISD::OutputArg> &Outs,
349 const SmallVectorImpl<SDValue> &OutVals,
350 const SmallVectorImpl<ISD::InputArg> &Ins,
351 DebugLoc dl, SelectionDAG &DAG,
352 SmallVectorImpl<SDValue> &InVals) const {
354 MachineFunction& MF = DAG.getMachineFunction();
355 PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
356 PTXParamManager &PM = MFI->getParamManager();
358 assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
359 "Calls are not handled for the target device");
361 std::vector<SDValue> Ops;
362 // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
// Four fixed slots (Chain, #Ins, Callee, #Outs) plus one per result and one
// per argument.
363 Ops.resize(Outs.size() + Ins.size() + 4);
367 // Identify the callee function
368 const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
369 assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
370 "PTX function calls must be to PTX device functions");
371 Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
372 Ops[Ins.size()+2] = Callee;
374 // Generate STORE_PARAM nodes for each function argument. In PTX, function
375 // arguments are explicitly stored into .param variables and passed as
376 // arguments. There is no register/stack-based calling convention in PTX.
377 Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
378 for (unsigned i = 0; i != OutVals.size(); ++i) {
// NOTE(review): argument parameters are sized with getSizeInBits() while the
// result parameters below use getStoreSizeInBits() -- confirm this
// difference is intended.
379 unsigned Size = OutVals[i].getValueType().getSizeInBits();
380 unsigned Param = PM.addLocalParam(Size);
381 const std::string &ParamName = PM.getParamName(Param);
382 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
384 Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
385 ParamValue, OutVals[i]);
386 Ops[i+Ins.size()+4] = ParamValue;
// Symbols of the result parameters, collected while filling the Ops slots.
389 std::vector<SDValue> InParams;
391 // Generate list of .param variables to hold the return value(s).
392 Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32);
393 for (unsigned i = 0; i < Ins.size(); ++i) {
394 unsigned Size = Ins[i].VT.getStoreSizeInBits();
395 unsigned Param = PM.addLocalParam(Size);
396 const std::string &ParamName = PM.getParamName(Param);
397 SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
399 Ops[i+2] = ParamValue;
400 InParams.push_back(ParamValue);
405 // Create the CALL node.
406 Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
408 // Create the LOAD_PARAM nodes that retrieve the function return value(s).
409 for (unsigned i = 0; i < Ins.size(); ++i) {
410 SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
412 InVals.push_back(Load);
418 unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) {
419 // All arguments consist of one "register," regardless of the type.