PTX: Add preliminary support for floating-point divide and multiply-and-add
[oota-llvm.git] / lib / Target / PTX / PTXISelLowering.cpp
1 //===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the PTXTargetLowering class.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PTX.h"
15 #include "PTXISelLowering.h"
16 #include "PTXMachineFunctionInfo.h"
17 #include "PTXRegisterInfo.h"
18 #include "llvm/Support/ErrorHandling.h"
19 #include "llvm/CodeGen/MachineFunction.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/SelectionDAG.h"
22 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
23 #include "llvm/Support/raw_ostream.h"
24
25 using namespace llvm;
26
27 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
28   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
29   // Set up the register classes.
30   addRegisterClass(MVT::i1,  PTX::PredsRegisterClass);
31   addRegisterClass(MVT::i16, PTX::RRegu16RegisterClass);
32   addRegisterClass(MVT::i32, PTX::RRegu32RegisterClass);
33   addRegisterClass(MVT::i64, PTX::RRegu64RegisterClass);
34   addRegisterClass(MVT::f32, PTX::RRegf32RegisterClass);
35   addRegisterClass(MVT::f64, PTX::RRegf64RegisterClass);
36
37   setOperationAction(ISD::EXCEPTIONADDR, MVT::i32, Expand);
38
39   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
40   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
41
42   // Customize translation of memory addresses
43   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
44
45   // Compute derived properties from the register classes
46   computeRegisterProperties();
47 }
48
49 SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
50   switch (Op.getOpcode()) {
51     default:                 llvm_unreachable("Unimplemented operand");
52     case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG);
53   }
54 }
55
56 const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
57   switch (Opcode) {
58     default:
59       llvm_unreachable("Unknown opcode");
60     case PTXISD::READ_PARAM:
61       return "PTXISD::READ_PARAM";
62     case PTXISD::EXIT:
63       return "PTXISD::EXIT";
64     case PTXISD::RET:
65       return "PTXISD::RET";
66   }
67 }
68
69 //===----------------------------------------------------------------------===//
70 //                      Custom Lower Operation
71 //===----------------------------------------------------------------------===//
72
73 SDValue PTXTargetLowering::
74 LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
75   EVT PtrVT = getPointerTy();
76   DebugLoc dl = Op.getDebugLoc();
77   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
78   return DAG.getTargetGlobalAddress(GV, dl, PtrVT);
79 }
80
81 //===----------------------------------------------------------------------===//
82 //                      Calling Convention Implementation
83 //===----------------------------------------------------------------------===//
84
85 namespace {
86 struct argmap_entry {
87   MVT::SimpleValueType VT;
88   TargetRegisterClass *RC;
89   TargetRegisterClass::iterator loc;
90
91   argmap_entry(MVT::SimpleValueType _VT, TargetRegisterClass *_RC)
92     : VT(_VT), RC(_RC), loc(_RC->begin()) {}
93
94   void reset() { loc = RC->begin(); }
95   bool operator==(MVT::SimpleValueType _VT) const { return VT == _VT; }
96 } argmap[] = {
97   argmap_entry(MVT::i1,  PTX::PredsRegisterClass),
98   argmap_entry(MVT::i16, PTX::RRegu16RegisterClass),
99   argmap_entry(MVT::i32, PTX::RRegu32RegisterClass),
100   argmap_entry(MVT::i64, PTX::RRegu64RegisterClass),
101   argmap_entry(MVT::f32, PTX::RRegf32RegisterClass),
102   argmap_entry(MVT::f64, PTX::RRegf64RegisterClass)
103 };
104 }                               // end anonymous namespace
105
106 SDValue PTXTargetLowering::
107   LowerFormalArguments(SDValue Chain,
108                        CallingConv::ID CallConv,
109                        bool isVarArg,
110                        const SmallVectorImpl<ISD::InputArg> &Ins,
111                        DebugLoc dl,
112                        SelectionDAG &DAG,
113                        SmallVectorImpl<SDValue> &InVals) const {
114   if (isVarArg) llvm_unreachable("PTX does not support varargs");
115
116   MachineFunction &MF = DAG.getMachineFunction();
117   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
118
119   switch (CallConv) {
120     default:
121       llvm_unreachable("Unsupported calling convention");
122       break;
123     case CallingConv::PTX_Kernel:
124       MFI->setKernel(true);
125       break;
126     case CallingConv::PTX_Device:
127       MFI->setKernel(false);
128       break;
129   }
130
131   // Make sure we don't add argument registers twice
132   if (MFI->isDoneAddArg())
133     llvm_unreachable("cannot add argument registers twice");
134
135   // Reset argmap before allocation
136   for (struct argmap_entry *i = argmap, *e = argmap + array_lengthof(argmap);
137        i != e; ++ i)
138     i->reset();
139
140   for (int i = 0, e = Ins.size(); i != e; ++ i) {
141     MVT::SimpleValueType VT = Ins[i].VT.SimpleTy;
142
143     struct argmap_entry *entry = std::find(argmap,
144                                            argmap + array_lengthof(argmap), VT);
145     if (entry == argmap + array_lengthof(argmap))
146       llvm_unreachable("Type of argument is not supported");
147
148     if (MFI->isKernel() && entry->RC == PTX::PredsRegisterClass)
149       llvm_unreachable("cannot pass preds to kernel");
150
151     MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
152
153     unsigned preg = *++(entry->loc); // allocate start from register 1
154     unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
155     RegInfo.addLiveIn(preg, vreg);
156
157     MFI->addArgReg(preg);
158
159     SDValue inval;
160     if (MFI->isKernel())
161       inval = DAG.getNode(PTXISD::READ_PARAM, dl, VT, Chain,
162                           DAG.getTargetConstant(i, MVT::i32));
163     else
164       inval = DAG.getCopyFromReg(Chain, dl, vreg, VT);
165     InVals.push_back(inval);
166   }
167
168   MFI->doneAddArg();
169
170   return Chain;
171 }
172
173 SDValue PTXTargetLowering::
174   LowerReturn(SDValue Chain,
175               CallingConv::ID CallConv,
176               bool isVarArg,
177               const SmallVectorImpl<ISD::OutputArg> &Outs,
178               const SmallVectorImpl<SDValue> &OutVals,
179               DebugLoc dl,
180               SelectionDAG &DAG) const {
181   if (isVarArg) llvm_unreachable("PTX does not support varargs");
182
183   switch (CallConv) {
184     default:
185       llvm_unreachable("Unsupported calling convention.");
186     case CallingConv::PTX_Kernel:
187       assert(Outs.size() == 0 && "Kernel must return void.");
188       return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
189     case CallingConv::PTX_Device:
190       assert(Outs.size() <= 1 && "Can at most return one value.");
191       break;
192   }
193
194   // PTX_Device
195
196   // return void
197   if (Outs.size() == 0)
198     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
199
200   SDValue Flag;
201   unsigned reg;
202
203   if (Outs[0].VT == MVT::i16) {
204     reg = PTX::RH0;
205   }
206   else if (Outs[0].VT == MVT::i32) {
207     reg = PTX::R0;
208   }
209   else if (Outs[0].VT == MVT::i64) {
210     reg = PTX::RD0;
211   }
212   else if (Outs[0].VT == MVT::f32) {
213     reg = PTX::F0;
214   }
215   else if (Outs[0].VT == MVT::f64) {
216     reg = PTX::FD0;
217   }
218   else {
219     assert(false && "Can return only basic types");
220   }
221
222   MachineFunction &MF = DAG.getMachineFunction();
223   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
224   MFI->setRetReg(reg);
225
226   // If this is the first return lowered for this function, add the regs to the
227   // liveout set for the function
228   if (DAG.getMachineFunction().getRegInfo().liveout_empty())
229     DAG.getMachineFunction().getRegInfo().addLiveOut(reg);
230
231   // Copy the result values into the output registers
232   Chain = DAG.getCopyToReg(Chain, dl, reg, OutVals[0], Flag);
233
234   // Guarantee that all emitted copies are stuck together,
235   // avoiding something bad
236   Flag = Chain.getValue(1);
237
238   return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
239 }