PTX: Finish new calling convention implementation
[oota-llvm.git] / lib / Target / PTX / PTXISelLowering.cpp
1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the PTXTargetLowering class.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PTX.h"
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "llvm/Support/ErrorHandling.h"
19 #include "llvm/CodeGen/CallingConvLower.h"
20 #include "llvm/CodeGen/MachineFunction.h"
21 #include "llvm/CodeGen/MachineRegisterInfo.h"
22 #include "llvm/CodeGen/SelectionDAG.h"
23 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
24 #include "llvm/Support/raw_ostream.h"
25
26 using namespace llvm;
27
28 //===----------------------------------------------------------------------===//
29 // Calling Convention Implementation
30 //===----------------------------------------------------------------------===//
31
32 #include "PTXGenCallingConv.inc"
33
34 //===----------------------------------------------------------------------===//
35 // TargetLowering Implementation
36 //===----------------------------------------------------------------------===//
37
38 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
39   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
40   // Set up the register classes.
41   addRegisterClass(MVT::i1,  PTX::RegPredRegisterClass);
42   addRegisterClass(MVT::i16, PTX::RegI16RegisterClass);
43   addRegisterClass(MVT::i32, PTX::RegI32RegisterClass);
44   addRegisterClass(MVT::i64, PTX::RegI64RegisterClass);
45   addRegisterClass(MVT::f32, PTX::RegF32RegisterClass);
46   addRegisterClass(MVT::f64, PTX::RegF64RegisterClass);
47
48   setBooleanContents(ZeroOrOneBooleanContent);
49
50   setOperationAction(ISD::EXCEPTIONADDR, MVT::i32, Expand);
51
52   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
53   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
54
55   // Turn i16 (z)extload into load + (z)extend
56   setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand);
57   setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand);
58
59   // Turn f32 extload into load + fextend
60   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
61
62   // Turn f64 truncstore into trunc + store.
63   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
64
65   // Customize translation of memory addresses
66   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
67   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
68
69   // Expand BR_CC into BRCOND
70   setOperationAction(ISD::BR_CC, MVT::Other, Expand);
71
72   // Expand SELECT_CC into SETCC
73   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
74   setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
75   setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
76
77   // need to lower SETCC of RegPred into bitwise logic
78   setOperationAction(ISD::SETCC, MVT::i1, Custom);
79
80   setMinFunctionAlignment(2);
81
82   // Compute derived properties from the register classes
83   computeRegisterProperties();
84 }
85
86 MVT::SimpleValueType PTXTargetLowering::getSetCCResultType(EVT VT) const {
87   return MVT::i1;
88 }
89
90 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
91   switch (Op.getOpcode()) {
92     default:
93       llvm_unreachable("Unimplemented operand");
94     case ISD::SETCC:
95       return LowerSETCC(Op, DAG);
96     case ISD::GlobalAddress:
97       return LowerGlobalAddress(Op, DAG);
98   }
99 }
100
101 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
102   switch (Opcode) {
103     default:
104       llvm_unreachable("Unknown opcode");
105     case PTXISD::COPY_ADDRESS:
106       return "PTXISD::COPY_ADDRESS";
107     case PTXISD::READ_PARAM:
108       return "PTXISD::READ_PARAM";
109     case PTXISD::EXIT:
110       return "PTXISD::EXIT";
111     case PTXISD::RET:
112       return "PTXISD::RET";
113   }
114 }
115
116 //===----------------------------------------------------------------------===//
117 //                      Custom Lower Operation
118 //===----------------------------------------------------------------------===//
119
120 SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
121   assert(Op.getValueType() == MVT::i1 && "SetCC type must be 1-bit integer");
122   SDValue Op0 = Op.getOperand(0);
123   SDValue Op1 = Op.getOperand(1);
124   SDValue Op2 = Op.getOperand(2);
125   DebugLoc dl = Op.getDebugLoc();
126   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
127
128   // Look for X == 0, X == 1, X != 0, or X != 1  
129   // We can simplify these to bitwise logic
130
131   if (Op1.getOpcode() == ISD::Constant &&
132       (cast<ConstantSDNode>(Op1)->getZExtValue() == 1 ||
133        cast<ConstantSDNode>(Op1)->isNullValue()) &&
134       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
135
136     return DAG.getNode(ISD::AND, dl, MVT::i1, Op0, Op1);
137   }
138
139   return DAG.getNode(ISD::SETCC, dl, MVT::i1, Op0, Op1, Op2);
140 }
141
142 SDValue PTXTargetLowering::
143 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
144   EVT PtrVT = getPointerTy();
145   DebugLoc dl = Op.getDebugLoc();
146   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
147
148   assert(PtrVT.isSimple() && "Pointer must be to primitive type.");
149
150   SDValue targetGlobal = DAG.getTargetGlobalAddress(GV, dl, PtrVT);
151   SDValue movInstr = DAG.getNode(PTXISD::COPY_ADDRESS,
152                                  dl,
153                                  PtrVT.getSimpleVT(),
154                                  targetGlobal);
155
156   return movInstr;
157 }
158
159 //===----------------------------------------------------------------------===//
160 //                      Calling Convention Implementation
161 //===----------------------------------------------------------------------===//
162
163 namespace {
164 struct argmap_entry {
165   MVT::SimpleValueType VT;
166   TargetRegisterClass *RC;
167   TargetRegisterClass::iterator loc;
168
169   argmap_entry(MVT::SimpleValueType _VT, TargetRegisterClass *_RC)
170     : VT(_VT), RC(_RC), loc(_RC->begin()) {}
171
172   void reset() { loc = RC->begin(); }
173   bool operator==(MVT::SimpleValueType _VT) const { return VT == _VT; }
174 } argmap[] = {
175   argmap_entry(MVT::i1,  PTX::RegPredRegisterClass),
176   argmap_entry(MVT::i16, PTX::RegI16RegisterClass),
177   argmap_entry(MVT::i32, PTX::RegI32RegisterClass),
178   argmap_entry(MVT::i64, PTX::RegI64RegisterClass),
179   argmap_entry(MVT::f32, PTX::RegF32RegisterClass),
180   argmap_entry(MVT::f64, PTX::RegF64RegisterClass)
181 };
182 }                               // end anonymous namespace
183
184 SDValue PTXTargetLowering::
185   LowerFormalArguments(SDValue Chain,
186                        CallingConv::ID CallConv,
187                        bool isVarArg,
188                        const SmallVectorImpl<ISD::InputArg> &Ins,
189                        DebugLoc dl,
190                        SelectionDAG &DAG,
191                        SmallVectorImpl<SDValue> &InVals) const {
192   if (isVarArg) llvm_unreachable("PTX does not support varargs");
193
194   MachineFunction &MF = DAG.getMachineFunction();
195   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
196
197   switch (CallConv) {
198     default:
199       llvm_unreachable("Unsupported calling convention");
200       break;
201     case CallingConv::PTX_Kernel:
202       MFI->setKernel(true);
203       break;
204     case CallingConv::PTX_Device:
205       MFI->setKernel(false);
206       break;
207   }
208
209   if (MFI->isKernel()) {
210     // For kernel functions, we just need to emit the proper READ_PARAM ISDs
211     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
212
213       assert(Ins[i].VT != MVT::i1 && "Kernels cannot take pred operands");
214
215       SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain,
216                                      DAG.getTargetConstant(i, MVT::i32));
217       InVals.push_back(ArgValue);
218
219       // Instead of storing a physical register in our argument list, we just
220       // store the total size of the parameter, in bits.  The ASM printer
221       // knows how to process this.
222       MFI->addArgReg(Ins[i].VT.getStoreSizeInBits());
223     }
224   }
225   else {
226     // For device functions, we use the PTX calling convention to do register
227     // assignments then create CopyFromReg ISDs for the allocated registers
228
229     SmallVector<CCValAssign, 16> ArgLocs;
230     CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs,
231                    *DAG.getContext());
232
233     CCInfo.AnalyzeFormalArguments(Ins, CC_PTX);
234
235     for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
236
237       CCValAssign&         VA    = ArgLocs[i];
238       EVT                  RegVT = VA.getLocVT();
239       TargetRegisterClass* TRC   = 0;
240
241       assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
242
243       // Determine which register class we need
244       if (RegVT == MVT::i1) {
245         TRC = PTX::RegPredRegisterClass;
246       }
247       else if (RegVT == MVT::i16) {
248         TRC = PTX::RegI16RegisterClass;
249       }
250       else if (RegVT == MVT::i32) {
251         TRC = PTX::RegI32RegisterClass;
252       }
253       else if (RegVT == MVT::i64) {
254         TRC = PTX::RegI64RegisterClass;
255       }
256       else if (RegVT == MVT::f32) {
257         TRC = PTX::RegF32RegisterClass;
258       }
259       else if (RegVT == MVT::f64) {
260         TRC = PTX::RegF64RegisterClass;
261       }
262       else {
263         llvm_unreachable("Unknown parameter type");
264       }
265
266       unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
267       MF.getRegInfo().addLiveIn(VA.getLocReg(), Reg);
268
269       SDValue ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
270       InVals.push_back(ArgValue);
271
272       MFI->addArgReg(VA.getLocReg());
273     }
274   }
275
276   return Chain;
277 }
278
279 SDValue PTXTargetLowering::
280   LowerReturn(SDValue Chain,
281               CallingConv::ID CallConv,
282               bool isVarArg,
283               const SmallVectorImpl<ISD::OutputArg> &Outs,
284               const SmallVectorImpl<SDValue> &OutVals,
285               DebugLoc dl,
286               SelectionDAG &DAG) const {
287   if (isVarArg) llvm_unreachable("PTX does not support varargs");
288
289   switch (CallConv) {
290     default:
291       llvm_unreachable("Unsupported calling convention.");
292     case CallingConv::PTX_Kernel:
293       assert(Outs.size() == 0 && "Kernel must return void.");
294       return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
295     case CallingConv::PTX_Device:
296       //assert(Outs.size() <= 1 && "Can at most return one value.");
297       break;
298   }
299
300   MachineFunction& MF = DAG.getMachineFunction();
301   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
302   SmallVector<CCValAssign, 16> RVLocs;
303   CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
304                  getTargetMachine(), RVLocs, *DAG.getContext());
305
306   SDValue Flag;
307
308   CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
309
310   for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
311
312     CCValAssign& VA  = RVLocs[i];
313
314     assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
315
316     unsigned Reg = VA.getLocReg();
317
318     DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
319
320     Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
321
322     // Guarantee that all emitted copies are stuck together,
323     // avoiding something bad
324     Flag = Chain.getValue(1);
325
326     MFI->addRetReg(Reg);
327   }
328
329   if (Flag.getNode() == 0) {
330     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
331   }
332   else {
333     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
334   }
335 }