c439d4c17313f5791b859dac1508066c24a04665
[oota-llvm.git] / lib / Target / PTX / PTXISelLowering.cpp
1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the PTXTargetLowering class.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PTX.h"
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "PTXSubtarget.h"
19 #include "llvm/Function.h"
20 #include "llvm/Support/ErrorHandling.h"
21 #include "llvm/CodeGen/CallingConvLower.h"
22 #include "llvm/CodeGen/MachineFunction.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/SelectionDAG.h"
25 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28
29 using namespace llvm;
30
31 //===----------------------------------------------------------------------===//
32 // TargetLowering Implementation
33 //===----------------------------------------------------------------------===//
34
35 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
36   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
37   // Set up the register classes.
38   addRegisterClass(MVT::i1,  PTX::RegPredRegisterClass);
39   addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
40   addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
41   addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
42   addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
43   addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
44
45   setBooleanContents(ZeroOrOneBooleanContent);
46   setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct?
47   setMinFunctionAlignment(2);
48
49   ////////////////////////////////////
50   /////////// Expansion //////////////
51   ////////////////////////////////////
52
53   // (any/zero/sign) extload => load + (any/zero/sign) extend
54
55   setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
56   setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
57   setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand);
58
59   // f32 extload => load + fextend
60
61   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
62
63   // f64 truncstore => trunc + store
64
65   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
66
67   // sign_extend_inreg => sign_extend
68
69   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
70
71   // br_cc => brcond
72
73   setOperationAction(ISD::BR_CC, MVT::Other, Expand);
74
75   // select_cc => setcc
76
77   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
78   setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
79   setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
80
81   ////////////////////////////////////
82   //////////// Legal /////////////////
83   ////////////////////////////////////
84
85   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
86   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
87
88   ////////////////////////////////////
89   //////////// Custom ////////////////
90   ////////////////////////////////////
91
92   // customise setcc to use bitwise logic if possible
93
94   setOperationAction(ISD::SETCC, MVT::i1, Custom);
95
96   // customize translation of memory addresses
97
98   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
99   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
100
101   // Compute derived properties from the register classes
102   computeRegisterProperties();
103 }
104
105 EVT PTXTargetLowering::getSetCCResultType(EVT VT) const {
106   return MVT::i1;
107 }
108
109 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
110   switch (Op.getOpcode()) {
111     default:
112       llvm_unreachable("Unimplemented operand");
113     case ISD::SETCC:
114       return LowerSETCC(Op, DAG);
115     case ISD::GlobalAddress:
116       return LowerGlobalAddress(Op, DAG);
117   }
118 }
119
120 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
121   switch (Opcode) {
122     default:
123       llvm_unreachable("Unknown opcode");
124     case PTXISD::COPY_ADDRESS:
125       return "PTXISD::COPY_ADDRESS";
126     case PTXISD::LOAD_PARAM:
127       return "PTXISD::LOAD_PARAM";
128     case PTXISD::STORE_PARAM:
129       return "PTXISD::STORE_PARAM";
130     case PTXISD::READ_PARAM:
131       return "PTXISD::READ_PARAM";
132     case PTXISD::WRITE_PARAM:
133       return "PTXISD::WRITE_PARAM";
134     case PTXISD::EXIT:
135       return "PTXISD::EXIT";
136     case PTXISD::RET:
137       return "PTXISD::RET";
138     case PTXISD::CALL:
139       return "PTXISD::CALL";
140   }
141 }
142
143 //===----------------------------------------------------------------------===//
144 //                      Custom Lower Operation
145 //===----------------------------------------------------------------------===//
146
147 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
148   assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
149   SDValue Op0 = Op.getOperand(0);
150   SDValue Op1 = Op.getOperand(1);
151   SDValue Op2 = Op.getOperand(2);
152   DebugLoc dl = Op.getDebugLoc();
153   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
154
155   // Look for X == 0, X == 1, X != 0, or X != 1
156   // We can simplify these to bitwise logic
157
158   if (Op1.getOpcode() == ISD::Constant &&
159       (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
160        cast<ConstantSDNode>(Op1)->isNullValue()) &&
161       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
162
163     return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
164   }
165
166   return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
167 }
168
169 SDValue PTXTargetLowering::
170 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
171   EVT PtrVT = getPointerTy();
172   DebugLoc dl = Op.getDebugLoc();
173   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
174
175   assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
176
177   SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
178   SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
179                                  dl,
180                                  PtrVT.getSimpleVT(),
181                                  targetGlobal);
182
183   return movInstr;
184 }
185
186 //===----------------------------------------------------------------------===//
187 //                      Calling Convention Implementation
188 //===----------------------------------------------------------------------===//
189
190 SDValue PTXTargetLowering::
191   LowerFormalArguments(SDValue Chain,
192                        CallingConv::ID CallConv,
193                        bool isVarArg,
194                        const SmallVectorImpl<ISD::InputArg> &Ins,
195                        DebugLoc dl,
196                        SelectionDAG &DAG,
197                        SmallVectorImpl<SDValue> &InVals) const {
198   if (isVarArg) llvm_unreachable("PTX does not support varargs");
199
200   MachineFunction &MF = DAG.getMachineFunction();
201   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
202   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
203   PTXParamManager &PM = MFI->getParamManager();
204
205   switch (CallConv) {
206     default:
207       llvm_unreachable("Unsupported calling convention");
208       break;
209     case CallingConv::PTX_Kernel:
210       MFI->setKernel(true);
211       break;
212     case CallingConv::PTX_Device:
213       MFI->setKernel(false);
214       break;
215   }
216
217   // We do one of two things here:
218   // IsKernel || SM >= 2.0  ->  Use param space for arguments
219   // SM < 2.0               ->  Use registers for arguments
220   if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
221     // We just need to emit the proper LOAD_PARAM ISDs
222     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
223       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
224              "Kernels cannot take pred operands");
225
226       unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
227       unsigned Param = PM.addArgumentParam(ParamSize);
228       const std::string &ParamName = PM.getParamName(Param);
229       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
230                                                        MVT::Other);
231       SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
232                                      ParamValue);
233       InVals.push_back(ArgValue);
234     }
235   }
236   else {
237     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
238       EVT                  RegVT = Ins[i].VT;
239       TargetRegisterClass* TRC   = 0;
240       int                  OpCode;
241
242       // Determine which register class we need
243       if (RegVT == MVT::i1) {
244         TRC = PTX::RegPredRegisterClass;
245         OpCode = PTX::READPARAMPRED;
246       }
247       else if (RegVT == MVT::i16) {
248         TRC = PTX::RegI16RegisterClass;
249         OpCode = PTX::READPARAMI16;
250       }
251       else if (RegVT == MVT::i32) {
252         TRC = PTX::RegI32RegisterClass;
253         OpCode = PTX::READPARAMI32;
254       }
255       else if (RegVT == MVT::i64) {
256         TRC = PTX::RegI64RegisterClass;
257         OpCode = PTX::READPARAMI64;
258       }
259       else if (RegVT == MVT::f32) {
260         TRC = PTX::RegF32RegisterClass;
261         OpCode = PTX::READPARAMF32;
262       }
263       else if (RegVT == MVT::f64) {
264         TRC = PTX::RegF64RegisterClass;
265         OpCode = PTX::READPARAMF64;
266       }
267       else {
268         llvm_unreachable("Unknown parameter type");
269       }
270
271       // Use a unique index in the instruction to prevent instruction folding.
272       // Yes, this is a hack.
273       SDValue Index = DAG.getTargetConstant(i, MVT::i32);
274       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
275       SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
276                                      Index);
277
278       SDValue Flag = ArgValue.getValue(1);
279
280       SDValue Copy = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
281       SDValue RegValue = DAG.getRegister(Reg, RegVT);
282       InVals.push_back(ArgValue);
283
284       MFI->addArgReg(Reg);
285     }
286   }
287
288   return Chain;
289 }
290
291 SDValue PTXTargetLowering::
292   LowerReturn(SDValue Chain,
293               CallingConv::ID CallConv,
294               bool isVarArg,
295               const SmallVectorImpl<ISD::OutputArg> &Outs,
296               const SmallVectorImpl<SDValue> &OutVals,
297               DebugLoc dl,
298               SelectionDAG &DAG) const {
299   if (isVarArg) llvm_unreachable("PTX does not support varargs");
300
301   switch (CallConv) {
302     default:
303       llvm_unreachable("Unsupported calling convention.");
304     case CallingConv::PTX_Kernel:
305       assert(Outs.size() == 0 && "Kernel must return void.");
306       return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
307     case CallingConv::PTX_Device:
308       assert(Outs.size() <= 1 && "Can at most return one value.");
309       break;
310   }
311
312   MachineFunction& MF = DAG.getMachineFunction();
313   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
314   PTXParamManager &PM = MFI->getParamManager();
315
316   SDValue Flag;
317   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
318
319   if (ST.useParamSpaceForDeviceArgs()) {
320     assert(Outs.size() < 2 && "Device functions can return at most one value");
321
322     if (Outs.size() == 1) {
323       unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
324       unsigned Param = PM.addReturnParam(ParamSize);
325       const std::string &ParamName = PM.getParamName(Param);
326       SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
327                                                        MVT::Other);
328       Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
329                           ParamValue, OutVals[0]);
330     }
331   } else {
332     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
333       EVT                  RegVT = Outs[i].VT;
334       TargetRegisterClass* TRC = 0;
335
336       // Determine which register class we need
337       if (RegVT == MVT::i1) {
338         TRC = PTX::RegPredRegisterClass;
339       }
340       else if (RegVT == MVT::i16) {
341         TRC = PTX::RegI16RegisterClass;
342       }
343       else if (RegVT == MVT::i32) {
344         TRC = PTX::RegI32RegisterClass;
345       }
346       else if (RegVT == MVT::i64) {
347         TRC = PTX::RegI64RegisterClass;
348       }
349       else if (RegVT == MVT::f32) {
350         TRC = PTX::RegF32RegisterClass;
351       }
352       else if (RegVT == MVT::f64) {
353         TRC = PTX::RegF64RegisterClass;
354       }
355       else {
356         llvm_unreachable("Unknown parameter type");
357       }
358
359       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
360
361       SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
362       SDValue OutReg = DAG.getRegister(Reg, RegVT);
363
364       Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
365
366       MFI->addRetReg(Reg);
367     }
368   }
369
370   if (Flag.getNode() == 0) {
371     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
372   }
373   else {
374     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
375   }
376 }
377
378 SDValue
379 PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
380                              CallingConv::ID CallConv, bool isVarArg,
381                              bool &isTailCall,
382                              const SmallVectorImpl<ISD::OutputArg> &Outs,
383                              const SmallVectorImpl<SDValue> &OutVals,
384                              const SmallVectorImpl<ISD::InputArg> &Ins,
385                              DebugLoc dl, SelectionDAG &DAG,
386                              SmallVectorImpl<SDValue> &InVals) const {
387
388   MachineFunction& MF = DAG.getMachineFunction();
389   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
390   PTXParamManager &PM = MFI->getParamManager();
391
392   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
393          "Calls are not handled for the target device");
394
395   std::vector<SDValue> Ops;
396   // The layout of the ops will be [Chain, Ins, Callee, Outs]
397   Ops.resize(Outs.size() + Ins.size() + 2);
398
399   Ops[0] = Chain;
400
401   // Identify the callee function
402   const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
403   assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
404          "PTX function calls must be to PTX device functions");
405   Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
406   Ops[Ins.size()+1] = Callee;
407
408   // Generate STORE_PARAM nodes for each function argument.  In PTX, function
409   // arguments are explicitly stored into .param variables and passed as
410   // arguments. There is no register/stack-based calling convention in PTX.
411   for (unsigned i = 0; i != OutVals.size(); ++i) {
412     unsigned Size = OutVals[i].getValueType().getSizeInBits();
413     unsigned Param = PM.addLocalParam(Size);
414     const std::string &ParamName = PM.getParamName(Param);
415     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
416                                                      MVT::Other);
417     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
418                         ParamValue, OutVals[i]);
419     Ops[i+Ins.size()+2] = ParamValue;
420   }
421
422   std::vector<SDValue> InParams;
423
424   // Generate list of .param variables to hold the return value(s).
425   for (unsigned i = 0; i < Ins.size(); ++i) {
426     unsigned Size = Ins[i].VT.getStoreSizeInBits();
427     unsigned Param = PM.addLocalParam(Size);
428     const std::string &ParamName = PM.getParamName(Param);
429     SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
430                                                      MVT::Other);
431     Ops[i+1] = ParamValue;
432     InParams.push_back(ParamValue);
433   }
434
435   Ops[0] = Chain;
436
437   // Create the CALL node.
438   Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
439
440   // Create the LOAD_PARAM nodes that retrieve the function return value(s).
441   for (unsigned i = 0; i < Ins.size(); ++i) {
442     SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
443                                InParams[i]);
444     InVals.push_back(Load);
445   }
446
447   return Chain;
448 }