4590916a385ec7c418945831d396b3290f7ff999
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
1 //
2 //                     The LLVM Compiler Infrastructure
3 //
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
10 // selection DAG.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "NVPTXISelLowering.h"
15 #include "NVPTX.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <sstream>
38
39 #undef DEBUG_TYPE
40 #define DEBUG_TYPE "nvptx-lower"
41
42 using namespace llvm;
43
44 static unsigned int uniqueCallSite = 0;
45
46 static cl::opt<bool> sched4reg(
47     "nvptx-sched4reg",
48     cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
49
50 static bool IsPTXVectorType(MVT VT) {
51   switch (VT.SimpleTy) {
52   default:
53     return false;
54   case MVT::v2i1:
55   case MVT::v4i1:
56   case MVT::v2i8:
57   case MVT::v4i8:
58   case MVT::v2i16:
59   case MVT::v4i16:
60   case MVT::v2i32:
61   case MVT::v4i32:
62   case MVT::v2i64:
63   case MVT::v2f32:
64   case MVT::v4f32:
65   case MVT::v2f64:
66     return true;
67   }
68 }
69
70 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
71 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
72 /// into their primitive components.
73 /// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
74 /// same number of types as the Ins/Outs arrays in LowerFormalArguments,
75 /// LowerCall, and LowerReturn.
76 static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
77                                SmallVectorImpl<EVT> &ValueVTs,
78                                SmallVectorImpl<uint64_t> *Offsets = 0,
79                                uint64_t StartingOffset = 0) {
80   SmallVector<EVT, 16> TempVTs;
81   SmallVector<uint64_t, 16> TempOffsets;
82
83   ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
84   for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
85     EVT VT = TempVTs[i];
86     uint64_t Off = TempOffsets[i];
87     if (VT.isVector())
88       for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
89         ValueVTs.push_back(VT.getVectorElementType());
90         if (Offsets)
91           Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
92       }
93     else {
94       ValueVTs.push_back(VT);
95       if (Offsets)
96         Offsets->push_back(Off);
97     }
98   }
99 }
100
101 // NVPTXTargetLowering Constructor.
102 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
103     : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
104       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
105
106   // always lower memset, memcpy, and memmove intrinsics to load/store
107   // instructions, rather
108   // then generating calls to memset, mempcy or memmove.
109   MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
110   MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
111   MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
112
113   setBooleanContents(ZeroOrNegativeOneBooleanContent);
114
115   // Jump is Expensive. Don't create extra control flow for 'and', 'or'
116   // condition branches.
117   setJumpIsExpensive(true);
118
119   // By default, use the Source scheduling
120   if (sched4reg)
121     setSchedulingPreference(Sched::RegPressure);
122   else
123     setSchedulingPreference(Sched::Source);
124
125   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
126   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
127   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
128   addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
129   addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
130   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
131
132   // Operations not directly supported by NVPTX.
133   setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
134   setOperationAction(ISD::BR_CC, MVT::f32, Expand);
135   setOperationAction(ISD::BR_CC, MVT::f64, Expand);
136   setOperationAction(ISD::BR_CC, MVT::i1, Expand);
137   setOperationAction(ISD::BR_CC, MVT::i8, Expand);
138   setOperationAction(ISD::BR_CC, MVT::i16, Expand);
139   setOperationAction(ISD::BR_CC, MVT::i32, Expand);
140   setOperationAction(ISD::BR_CC, MVT::i64, Expand);
141   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
142   // For others we will expand to a SHL/SRA pair.
143   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
144   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
145   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
146   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal);
147   setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
148
149   if (nvptxSubtarget.hasROT64()) {
150     setOperationAction(ISD::ROTL, MVT::i64, Legal);
151     setOperationAction(ISD::ROTR, MVT::i64, Legal);
152   } else {
153     setOperationAction(ISD::ROTL, MVT::i64, Expand);
154     setOperationAction(ISD::ROTR, MVT::i64, Expand);
155   }
156   if (nvptxSubtarget.hasROT32()) {
157     setOperationAction(ISD::ROTL, MVT::i32, Legal);
158     setOperationAction(ISD::ROTR, MVT::i32, Legal);
159   } else {
160     setOperationAction(ISD::ROTL, MVT::i32, Expand);
161     setOperationAction(ISD::ROTR, MVT::i32, Expand);
162   }
163
164   setOperationAction(ISD::ROTL, MVT::i16, Expand);
165   setOperationAction(ISD::ROTR, MVT::i16, Expand);
166   setOperationAction(ISD::ROTL, MVT::i8, Expand);
167   setOperationAction(ISD::ROTR, MVT::i8, Expand);
168   setOperationAction(ISD::BSWAP, MVT::i16, Expand);
169   setOperationAction(ISD::BSWAP, MVT::i32, Expand);
170   setOperationAction(ISD::BSWAP, MVT::i64, Expand);
171
172   // Indirect branch is not supported.
173   // This also disables Jump Table creation.
174   setOperationAction(ISD::BR_JT, MVT::Other, Expand);
175   setOperationAction(ISD::BRIND, MVT::Other, Expand);
176
177   setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
178   setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
179
180   // We want to legalize constant related memmove and memcopy
181   // intrinsics.
182   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
183
184   // Turn FP extload into load/fextend
185   setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
186   // Turn FP truncstore into trunc + store.
187   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
188
189   // PTX does not support load / store predicate registers
190   setOperationAction(ISD::LOAD, MVT::i1, Custom);
191   setOperationAction(ISD::STORE, MVT::i1, Custom);
192
193   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
194   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
195   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
196   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
197   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
198   setTruncStoreAction(MVT::i8, MVT::i1, Expand);
199
200   // This is legal in NVPTX
201   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
202   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
203
204   // TRAP can be lowered to PTX trap
205   setOperationAction(ISD::TRAP, MVT::Other, Legal);
206
207   // Register custom handling for vector loads/stores
208   for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
209        ++i) {
210     MVT VT = (MVT::SimpleValueType) i;
211     if (IsPTXVectorType(VT)) {
212       setOperationAction(ISD::LOAD, VT, Custom);
213       setOperationAction(ISD::STORE, VT, Custom);
214       setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
215     }
216   }
217
218   // Custom handling for i8 intrinsics
219   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
220
221   setOperationAction(ISD::CTLZ, MVT::i16, Legal);
222   setOperationAction(ISD::CTLZ, MVT::i32, Legal);
223   setOperationAction(ISD::CTLZ, MVT::i64, Legal);
224   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
225   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
226   setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
227   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
228   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
229   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
230   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
231   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
232   setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
233   setOperationAction(ISD::CTPOP, MVT::i16, Legal);
234   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
235   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
236
237   // Now deduce the information based on the above mentioned
238   // actions
239   computeRegisterProperties();
240 }
241
242 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
243   switch (Opcode) {
244   default:
245     return 0;
246   case NVPTXISD::CALL:
247     return "NVPTXISD::CALL";
248   case NVPTXISD::RET_FLAG:
249     return "NVPTXISD::RET_FLAG";
250   case NVPTXISD::Wrapper:
251     return "NVPTXISD::Wrapper";
252   case NVPTXISD::NVBuiltin:
253     return "NVPTXISD::NVBuiltin";
254   case NVPTXISD::DeclareParam:
255     return "NVPTXISD::DeclareParam";
256   case NVPTXISD::DeclareScalarParam:
257     return "NVPTXISD::DeclareScalarParam";
258   case NVPTXISD::DeclareRet:
259     return "NVPTXISD::DeclareRet";
260   case NVPTXISD::DeclareRetParam:
261     return "NVPTXISD::DeclareRetParam";
262   case NVPTXISD::PrintCall:
263     return "NVPTXISD::PrintCall";
264   case NVPTXISD::LoadParam:
265     return "NVPTXISD::LoadParam";
266   case NVPTXISD::LoadParamV2:
267     return "NVPTXISD::LoadParamV2";
268   case NVPTXISD::LoadParamV4:
269     return "NVPTXISD::LoadParamV4";
270   case NVPTXISD::StoreParam:
271     return "NVPTXISD::StoreParam";
272   case NVPTXISD::StoreParamV2:
273     return "NVPTXISD::StoreParamV2";
274   case NVPTXISD::StoreParamV4:
275     return "NVPTXISD::StoreParamV4";
276   case NVPTXISD::StoreParamS32:
277     return "NVPTXISD::StoreParamS32";
278   case NVPTXISD::StoreParamU32:
279     return "NVPTXISD::StoreParamU32";
280   case NVPTXISD::CallArgBegin:
281     return "NVPTXISD::CallArgBegin";
282   case NVPTXISD::CallArg:
283     return "NVPTXISD::CallArg";
284   case NVPTXISD::LastCallArg:
285     return "NVPTXISD::LastCallArg";
286   case NVPTXISD::CallArgEnd:
287     return "NVPTXISD::CallArgEnd";
288   case NVPTXISD::CallVoid:
289     return "NVPTXISD::CallVoid";
290   case NVPTXISD::CallVal:
291     return "NVPTXISD::CallVal";
292   case NVPTXISD::CallSymbol:
293     return "NVPTXISD::CallSymbol";
294   case NVPTXISD::Prototype:
295     return "NVPTXISD::Prototype";
296   case NVPTXISD::MoveParam:
297     return "NVPTXISD::MoveParam";
298   case NVPTXISD::StoreRetval:
299     return "NVPTXISD::StoreRetval";
300   case NVPTXISD::StoreRetvalV2:
301     return "NVPTXISD::StoreRetvalV2";
302   case NVPTXISD::StoreRetvalV4:
303     return "NVPTXISD::StoreRetvalV4";
304   case NVPTXISD::PseudoUseParam:
305     return "NVPTXISD::PseudoUseParam";
306   case NVPTXISD::RETURN:
307     return "NVPTXISD::RETURN";
308   case NVPTXISD::CallSeqBegin:
309     return "NVPTXISD::CallSeqBegin";
310   case NVPTXISD::CallSeqEnd:
311     return "NVPTXISD::CallSeqEnd";
312   case NVPTXISD::LoadV2:
313     return "NVPTXISD::LoadV2";
314   case NVPTXISD::LoadV4:
315     return "NVPTXISD::LoadV4";
316   case NVPTXISD::LDGV2:
317     return "NVPTXISD::LDGV2";
318   case NVPTXISD::LDGV4:
319     return "NVPTXISD::LDGV4";
320   case NVPTXISD::LDUV2:
321     return "NVPTXISD::LDUV2";
322   case NVPTXISD::LDUV4:
323     return "NVPTXISD::LDUV4";
324   case NVPTXISD::StoreV2:
325     return "NVPTXISD::StoreV2";
326   case NVPTXISD::StoreV4:
327     return "NVPTXISD::StoreV4";
328   }
329 }
330
331 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
332   return VT == MVT::i1;
333 }
334
335 SDValue
336 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
337   SDLoc dl(Op);
338   const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
339   Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
340   return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
341 }
342
343 /*
344 std::string NVPTXTargetLowering::getPrototype(
345     Type *retTy, const ArgListTy &Args,
346     const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
347
348   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
349
350   std::stringstream O;
351   O << "prototype_" << uniqueCallSite << " : .callprototype ";
352
353   if (retTy->getTypeID() == Type::VoidTyID)
354     O << "()";
355   else {
356     O << "(";
357     if (isABI) {
358       if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
359         unsigned size = 0;
360         if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
361           size = ITy->getBitWidth();
362           if (size < 32)
363             size = 32;
364         } else {
365           assert(retTy->isFloatingPointTy() &&
366                  "Floating point type expected here");
367           size = retTy->getPrimitiveSizeInBits();
368         }
369
370         O << ".param .b" << size << " _";
371       } else if (isa<PointerType>(retTy))
372         O << ".param .b" << getPointerTy().getSizeInBits() << " _";
373       else {
374         if ((retTy->getTypeID() == Type::StructTyID) ||
375             isa<VectorType>(retTy)) {
376           SmallVector<EVT, 16> vtparts;
377           ComputeValueVTs(*this, retTy, vtparts);
378           unsigned totalsz = 0;
379           for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
380             unsigned elems = 1;
381             EVT elemtype = vtparts[i];
382             if (vtparts[i].isVector()) {
383               elems = vtparts[i].getVectorNumElements();
384               elemtype = vtparts[i].getVectorElementType();
385             }
386             for (unsigned j = 0, je = elems; j != je; ++j) {
387               unsigned sz = elemtype.getSizeInBits();
388               if (elemtype.isInteger() && (sz < 8))
389                 sz = 8;
390               totalsz += sz / 8;
391             }
392           }
393           O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
394         } else {
395           assert(false && "Unknown return type");
396         }
397       }
398     } else {
399       SmallVector<EVT, 16> vtparts;
400       ComputeValueVTs(*this, retTy, vtparts);
401       unsigned idx = 0;
402       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
403         unsigned elems = 1;
404         EVT elemtype = vtparts[i];
405         if (vtparts[i].isVector()) {
406           elems = vtparts[i].getVectorNumElements();
407           elemtype = vtparts[i].getVectorElementType();
408         }
409
410         for (unsigned j = 0, je = elems; j != je; ++j) {
411           unsigned sz = elemtype.getSizeInBits();
412           if (elemtype.isInteger() && (sz < 32))
413             sz = 32;
414           O << ".reg .b" << sz << " _";
415           if (j < je - 1)
416             O << ", ";
417           ++idx;
418         }
419         if (i < e - 1)
420           O << ", ";
421       }
422     }
423     O << ") ";
424   }
425   O << "_ (";
426
427   bool first = true;
428   MVT thePointerTy = getPointerTy();
429
430   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
431     const Type *Ty = Args[i].Ty;
432     if (!first) {
433       O << ", ";
434     }
435     first = false;
436
437     if (Outs[i].Flags.isByVal() == false) {
438       unsigned sz = 0;
439       if (isa<IntegerType>(Ty)) {
440         sz = cast<IntegerType>(Ty)->getBitWidth();
441         if (sz < 32)
442           sz = 32;
443       } else if (isa<PointerType>(Ty))
444         sz = thePointerTy.getSizeInBits();
445       else
446         sz = Ty->getPrimitiveSizeInBits();
447       if (isABI)
448         O << ".param .b" << sz << " ";
449       else
450         O << ".reg .b" << sz << " ";
451       O << "_";
452       continue;
453     }
454     const PointerType *PTy = dyn_cast<PointerType>(Ty);
455     assert(PTy && "Param with byval attribute should be a pointer type");
456     Type *ETy = PTy->getElementType();
457
458     if (isABI) {
459       unsigned align = Outs[i].Flags.getByValAlign();
460       unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
461       O << ".param .align " << align << " .b8 ";
462       O << "_";
463       O << "[" << sz << "]";
464       continue;
465     } else {
466       SmallVector<EVT, 16> vtparts;
467       ComputeValueVTs(*this, ETy, vtparts);
468       for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
469         unsigned elems = 1;
470         EVT elemtype = vtparts[i];
471         if (vtparts[i].isVector()) {
472           elems = vtparts[i].getVectorNumElements();
473           elemtype = vtparts[i].getVectorElementType();
474         }
475
476         for (unsigned j = 0, je = elems; j != je; ++j) {
477           unsigned sz = elemtype.getSizeInBits();
478           if (elemtype.isInteger() && (sz < 32))
479             sz = 32;
480           O << ".reg .b" << sz << " ";
481           O << "_";
482           if (j < je - 1)
483             O << ", ";
484         }
485         if (i < e - 1)
486           O << ", ";
487       }
488       continue;
489     }
490   }
491   O << ");";
492   return O.str();
493 }*/
494
495 std::string
496 NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
497                                   const SmallVectorImpl<ISD::OutputArg> &Outs,
498                                   unsigned retAlignment,
499                                   const ImmutableCallSite *CS) const {
500
501   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
502   assert(isABI && "Non-ABI compilation is not supported");
503   if (!isABI)
504     return "";
505
506   std::stringstream O;
507   O << "prototype_" << uniqueCallSite << " : .callprototype ";
508
509   if (retTy->getTypeID() == Type::VoidTyID) {
510     O << "()";
511   } else {
512     O << "(";
513     if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
514       unsigned size = 0;
515       if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
516         size = ITy->getBitWidth();
517         if (size < 32)
518           size = 32;
519       } else {
520         assert(retTy->isFloatingPointTy() &&
521                "Floating point type expected here");
522         size = retTy->getPrimitiveSizeInBits();
523       }
524
525       O << ".param .b" << size << " _";
526     } else if (isa<PointerType>(retTy)) {
527       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
528     } else {
529       if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
530         SmallVector<EVT, 16> vtparts;
531         ComputeValueVTs(*this, retTy, vtparts);
532         unsigned totalsz = 0;
533         for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
534           unsigned elems = 1;
535           EVT elemtype = vtparts[i];
536           if (vtparts[i].isVector()) {
537             elems = vtparts[i].getVectorNumElements();
538             elemtype = vtparts[i].getVectorElementType();
539           }
540           // TODO: no need to loop
541           for (unsigned j = 0, je = elems; j != je; ++j) {
542             unsigned sz = elemtype.getSizeInBits();
543             if (elemtype.isInteger() && (sz < 8))
544               sz = 8;
545             totalsz += sz / 8;
546           }
547         }
548         O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
549       } else {
550         assert(false && "Unknown return type");
551       }
552     }
553     O << ") ";
554   }
555   O << "_ (";
556
557   bool first = true;
558   MVT thePointerTy = getPointerTy();
559
560   unsigned OIdx = 0;
561   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
562     Type *Ty = Args[i].Ty;
563     if (!first) {
564       O << ", ";
565     }
566     first = false;
567
568     if (Outs[OIdx].Flags.isByVal() == false) {
569       if (Ty->isAggregateType() || Ty->isVectorTy()) {
570         unsigned align = 0;
571         const CallInst *CallI = cast<CallInst>(CS->getInstruction());
572         const DataLayout *TD = getDataLayout();
573         // +1 because index 0 is reserved for return type alignment
574         if (!llvm::getAlign(*CallI, i + 1, align))
575           align = TD->getABITypeAlignment(Ty);
576         unsigned sz = TD->getTypeAllocSize(Ty);
577         O << ".param .align " << align << " .b8 ";
578         O << "_";
579         O << "[" << sz << "]";
580         // update the index for Outs
581         SmallVector<EVT, 16> vtparts;
582         ComputeValueVTs(*this, Ty, vtparts);
583         if (unsigned len = vtparts.size())
584           OIdx += len - 1;
585         continue;
586       }
587       assert(getValueType(Ty) == Outs[OIdx].VT &&
588              "type mismatch between callee prototype and arguments");
589       // scalar type
590       unsigned sz = 0;
591       if (isa<IntegerType>(Ty)) {
592         sz = cast<IntegerType>(Ty)->getBitWidth();
593         if (sz < 32)
594           sz = 32;
595       } else if (isa<PointerType>(Ty))
596         sz = thePointerTy.getSizeInBits();
597       else
598         sz = Ty->getPrimitiveSizeInBits();
599       O << ".param .b" << sz << " ";
600       O << "_";
601       continue;
602     }
603     const PointerType *PTy = dyn_cast<PointerType>(Ty);
604     assert(PTy && "Param with byval attribute should be a pointer type");
605     Type *ETy = PTy->getElementType();
606
607     unsigned align = Outs[OIdx].Flags.getByValAlign();
608     unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
609     O << ".param .align " << align << " .b8 ";
610     O << "_";
611     O << "[" << sz << "]";
612   }
613   O << ");";
614   return O.str();
615 }
616
617 unsigned
618 NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
619                                           const ImmutableCallSite *CS,
620                                           Type *Ty,
621                                           unsigned Idx) const {
622   const DataLayout *TD = getDataLayout();
623   unsigned align = 0;
624   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
625
626   if (Func) { // direct call
627     assert(CS->getCalledFunction() &&
628            "direct call cannot find callee");
629     if (!llvm::getAlign(*(CS->getCalledFunction()), Idx, align))
630       align = TD->getABITypeAlignment(Ty);
631   }
632   else { // indirect call
633     const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
634     if (!llvm::getAlign(*CallI, Idx, align))
635       align = TD->getABITypeAlignment(Ty);
636   }
637
638   return align;
639 }
640
641 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
642                                        SmallVectorImpl<SDValue> &InVals) const {
643   SelectionDAG &DAG = CLI.DAG;
644   SDLoc dl = CLI.DL;
645   SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
646   SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
647   SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
648   SDValue Chain = CLI.Chain;
649   SDValue Callee = CLI.Callee;
650   bool &isTailCall = CLI.IsTailCall;
651   ArgListTy &Args = CLI.Args;
652   Type *retTy = CLI.RetTy;
653   ImmutableCallSite *CS = CLI.CS;
654
655   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
656   assert(isABI && "Non-ABI compilation is not supported");
657   if (!isABI)
658     return Chain;
659   const DataLayout *TD = getDataLayout();
660   MachineFunction &MF = DAG.getMachineFunction();
661   const Function *F = MF.getFunction();
662
663   SDValue tempChain = Chain;
664   Chain =
665       DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
666                            dl);
667   SDValue InFlag = Chain.getValue(1);
668
669   unsigned paramCount = 0;
670   // Args.size() and Outs.size() need not match.
671   // Outs.size() will be larger
672   //   * if there is an aggregate argument with multiple fields (each field
673   //     showing up separately in Outs)
674   //   * if there is a vector argument with more than typical vector-length
675   //     elements (generally if more than 4) where each vector element is
676   //     individually present in Outs.
677   // So a different index should be used for indexing into Outs/OutVals.
678   // See similar issue in LowerFormalArguments.
679   unsigned OIdx = 0;
680   // Declare the .params or .reg need to pass values
681   // to the function
682   for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
683     EVT VT = Outs[OIdx].VT;
684     Type *Ty = Args[i].Ty;
685
686     if (Outs[OIdx].Flags.isByVal() == false) {
687       if (Ty->isAggregateType()) {
688         // aggregate
689         SmallVector<EVT, 16> vtparts;
690         ComputeValueVTs(*this, Ty, vtparts);
691
692         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
693         // declare .param .align <align> .b8 .param<n>[<size>];
694         unsigned sz = TD->getTypeAllocSize(Ty);
695         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
696         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
697                                       DAG.getConstant(paramCount, MVT::i32),
698                                       DAG.getConstant(sz, MVT::i32), InFlag };
699         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
700                             DeclareParamOps, 5);
701         InFlag = Chain.getValue(1);
702         unsigned curOffset = 0;
703         for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
704           unsigned elems = 1;
705           EVT elemtype = vtparts[j];
706           if (vtparts[j].isVector()) {
707             elems = vtparts[j].getVectorNumElements();
708             elemtype = vtparts[j].getVectorElementType();
709           }
710           for (unsigned k = 0, ke = elems; k != ke; ++k) {
711             unsigned sz = elemtype.getSizeInBits();
712             if (elemtype.isInteger() && (sz < 8))
713               sz = 8;
714             SDValue StVal = OutVals[OIdx];
715             if (elemtype.getSizeInBits() < 16) {
716               StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
717             }
718             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
719             SDValue CopyParamOps[] = { Chain,
720                                        DAG.getConstant(paramCount, MVT::i32),
721                                        DAG.getConstant(curOffset, MVT::i32),
722                                        StVal, InFlag };
723             Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
724                                             CopyParamVTs, &CopyParamOps[0], 5,
725                                             elemtype, MachinePointerInfo());
726             InFlag = Chain.getValue(1);
727             curOffset += sz / 8;
728             ++OIdx;
729           }
730         }
731         if (vtparts.size() > 0)
732           --OIdx;
733         ++paramCount;
734         continue;
735       }
736       if (Ty->isVectorTy()) {
737         EVT ObjectVT = getValueType(Ty);
738         unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
739         // declare .param .align <align> .b8 .param<n>[<size>];
740         unsigned sz = TD->getTypeAllocSize(Ty);
741         SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
742         SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
743                                       DAG.getConstant(paramCount, MVT::i32),
744                                       DAG.getConstant(sz, MVT::i32), InFlag };
745         Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
746                             DeclareParamOps, 5);
747         InFlag = Chain.getValue(1);
748         unsigned NumElts = ObjectVT.getVectorNumElements();
749         EVT EltVT = ObjectVT.getVectorElementType();
750         EVT MemVT = EltVT;
751         bool NeedExtend = false;
752         if (EltVT.getSizeInBits() < 16) {
753           NeedExtend = true;
754           EltVT = MVT::i16;
755         }
756
757         // V1 store
758         if (NumElts == 1) {
759           SDValue Elt = OutVals[OIdx++];
760           if (NeedExtend)
761             Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
762
763           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
764           SDValue CopyParamOps[] = { Chain,
765                                      DAG.getConstant(paramCount, MVT::i32),
766                                      DAG.getConstant(0, MVT::i32), Elt,
767                                      InFlag };
768           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
769                                           CopyParamVTs, &CopyParamOps[0], 5,
770                                           MemVT, MachinePointerInfo());
771           InFlag = Chain.getValue(1);
772         } else if (NumElts == 2) {
773           SDValue Elt0 = OutVals[OIdx++];
774           SDValue Elt1 = OutVals[OIdx++];
775           if (NeedExtend) {
776             Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
777             Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
778           }
779
780           SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
781           SDValue CopyParamOps[] = { Chain,
782                                      DAG.getConstant(paramCount, MVT::i32),
783                                      DAG.getConstant(0, MVT::i32), Elt0, Elt1,
784                                      InFlag };
785           Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
786                                           CopyParamVTs, &CopyParamOps[0], 6,
787                                           MemVT, MachinePointerInfo());
788           InFlag = Chain.getValue(1);
789         } else {
790           unsigned curOffset = 0;
791           // V4 stores
792           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
793           // the
794           // vector will be expanded to a power of 2 elements, so we know we can
795           // always round up to the next multiple of 4 when creating the vector
796           // stores.
797           // e.g.  4 elem => 1 st.v4
798           //       6 elem => 2 st.v4
799           //       8 elem => 2 st.v4
800           //      11 elem => 3 st.v4
801           unsigned VecSize = 4;
802           if (EltVT.getSizeInBits() == 64)
803             VecSize = 2;
804
805           // This is potentially only part of a vector, so assume all elements
806           // are packed together.
807           unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
808
809           for (unsigned i = 0; i < NumElts; i += VecSize) {
810             // Get values
811             SDValue StoreVal;
812             SmallVector<SDValue, 8> Ops;
813             Ops.push_back(Chain);
814             Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
815             Ops.push_back(DAG.getConstant(curOffset, MVT::i32));
816
817             unsigned Opc = NVPTXISD::StoreParamV2;
818
819             StoreVal = OutVals[OIdx++];
820             if (NeedExtend)
821               StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
822             Ops.push_back(StoreVal);
823
824             if (i + 1 < NumElts) {
825               StoreVal = OutVals[OIdx++];
826               if (NeedExtend)
827                 StoreVal =
828                     DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
829             } else {
830               StoreVal = DAG.getUNDEF(EltVT);
831             }
832             Ops.push_back(StoreVal);
833
834             if (VecSize == 4) {
835               Opc = NVPTXISD::StoreParamV4;
836               if (i + 2 < NumElts) {
837                 StoreVal = OutVals[OIdx++];
838                 if (NeedExtend)
839                   StoreVal =
840                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
841               } else {
842                 StoreVal = DAG.getUNDEF(EltVT);
843               }
844               Ops.push_back(StoreVal);
845
846               if (i + 3 < NumElts) {
847                 StoreVal = OutVals[OIdx++];
848                 if (NeedExtend)
849                   StoreVal =
850                       DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
851               } else {
852                 StoreVal = DAG.getUNDEF(EltVT);
853               }
854               Ops.push_back(StoreVal);
855             }
856
857             SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
858             Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
859                                             Ops.size(), MemVT,
860                                             MachinePointerInfo());
861             InFlag = Chain.getValue(1);
862             curOffset += PerStoreOffset;
863           }
864         }
865         ++paramCount;
866         --OIdx;
867         continue;
868       }
869       // Plain scalar
870       // for ABI,    declare .param .b<size> .param<n>;
871       unsigned sz = VT.getSizeInBits();
872       bool needExtend = false;
873       if (VT.isInteger()) {
874         if (sz < 16)
875           needExtend = true;
876         if (sz < 32)
877           sz = 32;
878       }
879       SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
880       SDValue DeclareParamOps[] = { Chain,
881                                     DAG.getConstant(paramCount, MVT::i32),
882                                     DAG.getConstant(sz, MVT::i32),
883                                     DAG.getConstant(0, MVT::i32), InFlag };
884       Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
885                           DeclareParamOps, 5);
886       InFlag = Chain.getValue(1);
887       SDValue OutV = OutVals[OIdx];
888       if (needExtend) {
889         // zext/sext i1 to i16
890         unsigned opc = ISD::ZERO_EXTEND;
891         if (Outs[OIdx].Flags.isSExt())
892           opc = ISD::SIGN_EXTEND;
893         OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
894       }
895       SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
896       SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
897                                  DAG.getConstant(0, MVT::i32), OutV, InFlag };
898
899       unsigned opcode = NVPTXISD::StoreParam;
900       if (Outs[OIdx].Flags.isZExt())
901         opcode = NVPTXISD::StoreParamU32;
902       else if (Outs[OIdx].Flags.isSExt())
903         opcode = NVPTXISD::StoreParamS32;
904       Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
905                                       VT, MachinePointerInfo());
906
907       InFlag = Chain.getValue(1);
908       ++paramCount;
909       continue;
910     }
911     // struct or vector
912     SmallVector<EVT, 16> vtparts;
913     const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
914     assert(PTy && "Type of a byval parameter should be pointer");
915     ComputeValueVTs(*this, PTy->getElementType(), vtparts);
916
917     // declare .param .align <align> .b8 .param<n>[<size>];
918     unsigned sz = Outs[OIdx].Flags.getByValSize();
919     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
920     // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
921     // so we don't need to worry about natural alignment or not.
922     // See TargetLowering::LowerCallTo().
923     SDValue DeclareParamOps[] = {
924       Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
925       DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
926       InFlag
927     };
928     Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
929                         DeclareParamOps, 5);
930     InFlag = Chain.getValue(1);
931     unsigned curOffset = 0;
932     for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
933       unsigned elems = 1;
934       EVT elemtype = vtparts[j];
935       if (vtparts[j].isVector()) {
936         elems = vtparts[j].getVectorNumElements();
937         elemtype = vtparts[j].getVectorElementType();
938       }
939       for (unsigned k = 0, ke = elems; k != ke; ++k) {
940         unsigned sz = elemtype.getSizeInBits();
941         if (elemtype.isInteger() && (sz < 8))
942           sz = 8;
943         SDValue srcAddr =
944             DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
945                         DAG.getConstant(curOffset, getPointerTy()));
946         SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
947                                      MachinePointerInfo(), false, false, false,
948                                      0);
949         if (elemtype.getSizeInBits() < 16) {
950           theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
951         }
952         SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
953         SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
954                                    DAG.getConstant(curOffset, MVT::i32), theVal,
955                                    InFlag };
956         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
957                                         CopyParamOps, 5, elemtype,
958                                         MachinePointerInfo());
959
960         InFlag = Chain.getValue(1);
961         curOffset += sz / 8;
962       }
963     }
964     ++paramCount;
965   }
966
967   GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
968   unsigned retAlignment = 0;
969
970   // Handle Result
971   if (Ins.size() > 0) {
972     SmallVector<EVT, 16> resvtparts;
973     ComputeValueVTs(*this, retTy, resvtparts);
974
975     // Declare
976     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
977     //  .param .b<size-in-bits> retval0
978     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
979     if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
980         retTy->isPointerTy()) {
981       // Scalar needs to be at least 32bit wide
982       if (resultsz < 32)
983         resultsz = 32;
984       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
985       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
986                                   DAG.getConstant(resultsz, MVT::i32),
987                                   DAG.getConstant(0, MVT::i32), InFlag };
988       Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
989                           DeclareRetOps, 5);
990       InFlag = Chain.getValue(1);
991     } else {
992       retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
993       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
994       SDValue DeclareRetOps[] = { Chain,
995                                   DAG.getConstant(retAlignment, MVT::i32),
996                                   DAG.getConstant(resultsz / 8, MVT::i32),
997                                   DAG.getConstant(0, MVT::i32), InFlag };
998       Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
999                           DeclareRetOps, 5);
1000       InFlag = Chain.getValue(1);
1001     }
1002   }
1003
1004   if (!Func) {
1005     // This is indirect function call case : PTX requires a prototype of the
1006     // form
1007     // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
1008     // to be emitted, and the label has to used as the last arg of call
1009     // instruction.
1010     // The prototype is embedded in a string and put as the operand for an
1011     // INLINEASM SDNode.
1012     SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1013     std::string proto_string =
1014         getPrototype(retTy, Args, Outs, retAlignment, CS);
1015     const char *asmstr = nvTM->getManagedStrPool()
1016         ->getManagedString(proto_string.c_str())->c_str();
1017     SDValue InlineAsmOps[] = {
1018       Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
1019       DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
1020     };
1021     Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
1022     InFlag = Chain.getValue(1);
1023   }
1024   // Op to just print "call"
1025   SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1026   SDValue PrintCallOps[] = {
1027     Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
1028   };
1029   Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
1030                       dl, PrintCallVTs, PrintCallOps, 3);
1031   InFlag = Chain.getValue(1);
1032
1033   // Ops to print out the function name
1034   SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1035   SDValue CallVoidOps[] = { Chain, Callee, InFlag };
1036   Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
1037   InFlag = Chain.getValue(1);
1038
1039   // Ops to print out the param list
1040   SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1041   SDValue CallArgBeginOps[] = { Chain, InFlag };
1042   Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
1043                       CallArgBeginOps, 2);
1044   InFlag = Chain.getValue(1);
1045
1046   for (unsigned i = 0, e = paramCount; i != e; ++i) {
1047     unsigned opcode;
1048     if (i == (e - 1))
1049       opcode = NVPTXISD::LastCallArg;
1050     else
1051       opcode = NVPTXISD::CallArg;
1052     SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1053     SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
1054                              DAG.getConstant(i, MVT::i32), InFlag };
1055     Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
1056     InFlag = Chain.getValue(1);
1057   }
1058   SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1059   SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
1060                               InFlag };
1061   Chain =
1062       DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
1063   InFlag = Chain.getValue(1);
1064
1065   if (!Func) {
1066     SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
1067     SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
1068                                InFlag };
1069     Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
1070     InFlag = Chain.getValue(1);
1071   }
1072
1073   // Generate loads from param memory/moves from registers for result
1074   if (Ins.size() > 0) {
1075     unsigned resoffset = 0;
1076     if (retTy && retTy->isVectorTy()) {
1077       EVT ObjectVT = getValueType(retTy);
1078       unsigned NumElts = ObjectVT.getVectorNumElements();
1079       EVT EltVT = ObjectVT.getVectorElementType();
1080       assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
1081                                                         ObjectVT) == NumElts &&
1082              "Vector was not scalarized");
1083       unsigned sz = EltVT.getSizeInBits();
1084       bool needTruncate = sz < 16 ? true : false;
1085
1086       if (NumElts == 1) {
1087         // Just a simple load
1088         std::vector<EVT> LoadRetVTs;
1089         if (needTruncate) {
1090           // If loading i1 result, generate
1091           //   load i16
1092           //   trunc i16 to i1
1093           LoadRetVTs.push_back(MVT::i16);
1094         } else
1095           LoadRetVTs.push_back(EltVT);
1096         LoadRetVTs.push_back(MVT::Other);
1097         LoadRetVTs.push_back(MVT::Glue);
1098         std::vector<SDValue> LoadRetOps;
1099         LoadRetOps.push_back(Chain);
1100         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1101         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1102         LoadRetOps.push_back(InFlag);
1103         SDValue retval = DAG.getMemIntrinsicNode(
1104             NVPTXISD::LoadParam, dl,
1105             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1106             LoadRetOps.size(), EltVT, MachinePointerInfo());
1107         Chain = retval.getValue(1);
1108         InFlag = retval.getValue(2);
1109         SDValue Ret0 = retval;
1110         if (needTruncate)
1111           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
1112         InVals.push_back(Ret0);
1113       } else if (NumElts == 2) {
1114         // LoadV2
1115         std::vector<EVT> LoadRetVTs;
1116         if (needTruncate) {
1117           // If loading i1 result, generate
1118           //   load i16
1119           //   trunc i16 to i1
1120           LoadRetVTs.push_back(MVT::i16);
1121           LoadRetVTs.push_back(MVT::i16);
1122         } else {
1123           LoadRetVTs.push_back(EltVT);
1124           LoadRetVTs.push_back(EltVT);
1125         }
1126         LoadRetVTs.push_back(MVT::Other);
1127         LoadRetVTs.push_back(MVT::Glue);
1128         std::vector<SDValue> LoadRetOps;
1129         LoadRetOps.push_back(Chain);
1130         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1131         LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
1132         LoadRetOps.push_back(InFlag);
1133         SDValue retval = DAG.getMemIntrinsicNode(
1134             NVPTXISD::LoadParamV2, dl,
1135             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1136             LoadRetOps.size(), EltVT, MachinePointerInfo());
1137         Chain = retval.getValue(2);
1138         InFlag = retval.getValue(3);
1139         SDValue Ret0 = retval.getValue(0);
1140         SDValue Ret1 = retval.getValue(1);
1141         if (needTruncate) {
1142           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
1143           InVals.push_back(Ret0);
1144           Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
1145           InVals.push_back(Ret1);
1146         } else {
1147           InVals.push_back(Ret0);
1148           InVals.push_back(Ret1);
1149         }
1150       } else {
1151         // Split into N LoadV4
1152         unsigned Ofst = 0;
1153         unsigned VecSize = 4;
1154         unsigned Opc = NVPTXISD::LoadParamV4;
1155         if (EltVT.getSizeInBits() == 64) {
1156           VecSize = 2;
1157           Opc = NVPTXISD::LoadParamV2;
1158         }
1159         EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1160         for (unsigned i = 0; i < NumElts; i += VecSize) {
1161           SmallVector<EVT, 8> LoadRetVTs;
1162           if (needTruncate) {
1163             // If loading i1 result, generate
1164             //   load i16
1165             //   trunc i16 to i1
1166             for (unsigned j = 0; j < VecSize; ++j)
1167               LoadRetVTs.push_back(MVT::i16);
1168           } else {
1169             for (unsigned j = 0; j < VecSize; ++j)
1170               LoadRetVTs.push_back(EltVT);
1171           }
1172           LoadRetVTs.push_back(MVT::Other);
1173           LoadRetVTs.push_back(MVT::Glue);
1174           SmallVector<SDValue, 4> LoadRetOps;
1175           LoadRetOps.push_back(Chain);
1176           LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1177           LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
1178           LoadRetOps.push_back(InFlag);
1179           SDValue retval = DAG.getMemIntrinsicNode(
1180               Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
1181               &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
1182           if (VecSize == 2) {
1183             Chain = retval.getValue(2);
1184             InFlag = retval.getValue(3);
1185           } else {
1186             Chain = retval.getValue(4);
1187             InFlag = retval.getValue(5);
1188           }
1189
1190           for (unsigned j = 0; j < VecSize; ++j) {
1191             if (i + j >= NumElts)
1192               break;
1193             SDValue Elt = retval.getValue(j);
1194             if (needTruncate)
1195               Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
1196             InVals.push_back(Elt);
1197           }
1198           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1199         }
1200       }
1201     } else {
1202       SmallVector<EVT, 16> VTs;
1203       ComputePTXValueVTs(*this, retTy, VTs);
1204       assert(VTs.size() == Ins.size() && "Bad value decomposition");
1205       for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
1206         unsigned sz = VTs[i].getSizeInBits();
1207         bool needTruncate = sz < 8 ? true : false;
1208         if (VTs[i].isInteger() && (sz < 8))
1209           sz = 8;
1210
1211         SmallVector<EVT, 4> LoadRetVTs;
1212         EVT TheLoadType = VTs[i];
1213         if (retTy->isIntegerTy() &&
1214             TD->getTypeAllocSizeInBits(retTy) < 32) {
1215           // This is for integer types only, and specifically not for
1216           // aggregates.
1217           LoadRetVTs.push_back(MVT::i32);
1218           TheLoadType = MVT::i32;
1219         } else if (sz < 16) {
1220           // If loading i1/i8 result, generate
1221           //   load i8 (-> i16)
1222           //   trunc i16 to i1/i8
1223           LoadRetVTs.push_back(MVT::i16);
1224         } else
1225           LoadRetVTs.push_back(Ins[i].VT);
1226         LoadRetVTs.push_back(MVT::Other);
1227         LoadRetVTs.push_back(MVT::Glue);
1228
1229         SmallVector<SDValue, 4> LoadRetOps;
1230         LoadRetOps.push_back(Chain);
1231         LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
1232         LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
1233         LoadRetOps.push_back(InFlag);
1234         SDValue retval = DAG.getMemIntrinsicNode(
1235             NVPTXISD::LoadParam, dl,
1236             DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
1237             LoadRetOps.size(), TheLoadType, MachinePointerInfo());
1238         Chain = retval.getValue(1);
1239         InFlag = retval.getValue(2);
1240         SDValue Ret0 = retval.getValue(0);
1241         if (needTruncate)
1242           Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
1243         InVals.push_back(Ret0);
1244         resoffset += sz / 8;
1245       }
1246     }
1247   }
1248
1249   Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
1250                              DAG.getIntPtrConstant(uniqueCallSite + 1, true),
1251                              InFlag, dl);
1252   uniqueCallSite++;
1253
1254   // set isTailCall to false for now, until we figure out how to express
1255   // tail call optimization in PTX
1256   isTailCall = false;
1257   return Chain;
1258 }
1259
1260 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
1261 // (see LegalizeDAG.cpp). This is slow and uses local memory.
1262 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
1263 SDValue
1264 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
1265   SDNode *Node = Op.getNode();
1266   SDLoc dl(Node);
1267   SmallVector<SDValue, 8> Ops;
1268   unsigned NumOperands = Node->getNumOperands();
1269   for (unsigned i = 0; i < NumOperands; ++i) {
1270     SDValue SubOp = Node->getOperand(i);
1271     EVT VVT = SubOp.getNode()->getValueType(0);
1272     EVT EltVT = VVT.getVectorElementType();
1273     unsigned NumSubElem = VVT.getVectorNumElements();
1274     for (unsigned j = 0; j < NumSubElem; ++j) {
1275       Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
1276                                 DAG.getIntPtrConstant(j)));
1277     }
1278   }
1279   return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
1280                      Ops.size());
1281 }
1282
1283 SDValue
1284 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1285   switch (Op.getOpcode()) {
1286   case ISD::RETURNADDR:
1287     return SDValue();
1288   case ISD::FRAMEADDR:
1289     return SDValue();
1290   case ISD::GlobalAddress:
1291     return LowerGlobalAddress(Op, DAG);
1292   case ISD::INTRINSIC_W_CHAIN:
1293     return Op;
1294   case ISD::BUILD_VECTOR:
1295   case ISD::EXTRACT_SUBVECTOR:
1296     return Op;
1297   case ISD::CONCAT_VECTORS:
1298     return LowerCONCAT_VECTORS(Op, DAG);
1299   case ISD::STORE:
1300     return LowerSTORE(Op, DAG);
1301   case ISD::LOAD:
1302     return LowerLOAD(Op, DAG);
1303   default:
1304     llvm_unreachable("Custom lowering not defined for operation");
1305   }
1306 }
1307
1308 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
1309   if (Op.getValueType() == MVT::i1)
1310     return LowerLOADi1(Op, DAG);
1311   else
1312     return SDValue();
1313 }
1314
1315 // v = ld i1* addr
1316 //   =>
1317 // v1 = ld i8* addr (-> i16)
1318 // v = trunc i16 to i1
1319 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
1320   SDNode *Node = Op.getNode();
1321   LoadSDNode *LD = cast<LoadSDNode>(Node);
1322   SDLoc dl(Node);
1323   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
1324   assert(Node->getValueType(0) == MVT::i1 &&
1325          "Custom lowering for i1 load only");
1326   SDValue newLD =
1327       DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
1328                   LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
1329                   LD->isInvariant(), LD->getAlignment());
1330   SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
1331   // The legalizer (the caller) is expecting two values from the legalized
1332   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
1333   // in LegalizeDAG.cpp which also uses MergeValues.
1334   SDValue Ops[] = { result, LD->getChain() };
1335   return DAG.getMergeValues(Ops, 2, dl);
1336 }
1337
1338 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
1339   EVT ValVT = Op.getOperand(1).getValueType();
1340   if (ValVT == MVT::i1)
1341     return LowerSTOREi1(Op, DAG);
1342   else if (ValVT.isVector())
1343     return LowerSTOREVector(Op, DAG);
1344   else
1345     return SDValue();
1346 }
1347
1348 SDValue
1349 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
1350   SDNode *N = Op.getNode();
1351   SDValue Val = N->getOperand(1);
1352   SDLoc DL(N);
1353   EVT ValVT = Val.getValueType();
1354
1355   if (ValVT.isVector()) {
1356     // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1357     // legal.  We can (and should) split that into 2 stores of <2 x double> here
1358     // but I'm leaving that as a TODO for now.
1359     if (!ValVT.isSimple())
1360       return SDValue();
1361     switch (ValVT.getSimpleVT().SimpleTy) {
1362     default:
1363       return SDValue();
1364     case MVT::v2i8:
1365     case MVT::v2i16:
1366     case MVT::v2i32:
1367     case MVT::v2i64:
1368     case MVT::v2f32:
1369     case MVT::v2f64:
1370     case MVT::v4i8:
1371     case MVT::v4i16:
1372     case MVT::v4i32:
1373     case MVT::v4f32:
1374       // This is a "native" vector type
1375       break;
1376     }
1377
1378     unsigned Opcode = 0;
1379     EVT EltVT = ValVT.getVectorElementType();
1380     unsigned NumElts = ValVT.getVectorNumElements();
1381
1382     // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
1383     // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
1384     // stored type to i16 and propogate the "real" type as the memory type.
1385     bool NeedExt = false;
1386     if (EltVT.getSizeInBits() < 16)
1387       NeedExt = true;
1388
1389     switch (NumElts) {
1390     default:
1391       return SDValue();
1392     case 2:
1393       Opcode = NVPTXISD::StoreV2;
1394       break;
1395     case 4: {
1396       Opcode = NVPTXISD::StoreV4;
1397       break;
1398     }
1399     }
1400
1401     SmallVector<SDValue, 8> Ops;
1402
1403     // First is the chain
1404     Ops.push_back(N->getOperand(0));
1405
1406     // Then the split values
1407     for (unsigned i = 0; i < NumElts; ++i) {
1408       SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
1409                                    DAG.getIntPtrConstant(i));
1410       if (NeedExt)
1411         ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
1412       Ops.push_back(ExtVal);
1413     }
1414
1415     // Then any remaining arguments
1416     for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
1417       Ops.push_back(N->getOperand(i));
1418     }
1419
1420     MemSDNode *MemSD = cast<MemSDNode>(N);
1421
1422     SDValue NewSt = DAG.getMemIntrinsicNode(
1423         Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
1424         MemSD->getMemoryVT(), MemSD->getMemOperand());
1425
1426     //return DCI.CombineTo(N, NewSt, true);
1427     return NewSt;
1428   }
1429
1430   return SDValue();
1431 }
1432
1433 // st i1 v, addr
1434 //    =>
1435 // v1 = zxt v to i16
1436 // st.u8 i16, addr
1437 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
1438   SDNode *Node = Op.getNode();
1439   SDLoc dl(Node);
1440   StoreSDNode *ST = cast<StoreSDNode>(Node);
1441   SDValue Tmp1 = ST->getChain();
1442   SDValue Tmp2 = ST->getBasePtr();
1443   SDValue Tmp3 = ST->getValue();
1444   assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
1445   unsigned Alignment = ST->getAlignment();
1446   bool isVolatile = ST->isVolatile();
1447   bool isNonTemporal = ST->isNonTemporal();
1448   Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
1449   SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
1450                                      ST->getPointerInfo(), MVT::i8, isNonTemporal,
1451                                      isVolatile, Alignment);
1452   return Result;
1453 }
1454
1455 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1456                                         int idx, EVT v) const {
1457   std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1458   std::stringstream suffix;
1459   suffix << idx;
1460   *name += suffix.str();
1461   return DAG.getTargetExternalSymbol(name->c_str(), v);
1462 }
1463
1464 SDValue
1465 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1466   return getExtSymb(DAG, ".PARAM", idx, v);
1467 }
1468
1469 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1470   return getExtSymb(DAG, ".HLPPARAM", idx);
1471 }
1472
1473 // Check to see if the kernel argument is image*_t or sampler_t
1474
1475 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1476   static const char *const specialTypes[] = { "struct._image2d_t",
1477                                               "struct._image3d_t",
1478                                               "struct._sampler_t" };
1479
1480   const Type *Ty = arg->getType();
1481   const PointerType *PTy = dyn_cast<PointerType>(Ty);
1482
1483   if (!PTy)
1484     return false;
1485
1486   if (!context)
1487     return false;
1488
1489   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1490   const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1491
1492   for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1493     if (TypeName == specialTypes[i])
1494       return true;
1495
1496   return false;
1497 }
1498
1499 SDValue NVPTXTargetLowering::LowerFormalArguments(
1500     SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1501     const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1502     SmallVectorImpl<SDValue> &InVals) const {
1503   MachineFunction &MF = DAG.getMachineFunction();
1504   const DataLayout *TD = getDataLayout();
1505
1506   const Function *F = MF.getFunction();
1507   const AttributeSet &PAL = F->getAttributes();
1508   const TargetLowering *TLI = nvTM->getTargetLowering();
1509
1510   SDValue Root = DAG.getRoot();
1511   std::vector<SDValue> OutChains;
1512
1513   bool isKernel = llvm::isKernelFunction(*F);
1514   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1515   assert(isABI && "Non-ABI compilation is not supported");
1516   if (!isABI)
1517     return Chain;
1518
1519   std::vector<Type *> argTypes;
1520   std::vector<const Argument *> theArgs;
1521   for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1522        I != E; ++I) {
1523     theArgs.push_back(I);
1524     argTypes.push_back(I->getType());
1525   }
1526   // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1527   // Ins.size() will be larger
1528   //   * if there is an aggregate argument with multiple fields (each field
1529   //     showing up separately in Ins)
1530   //   * if there is a vector argument with more than typical vector-length
1531   //     elements (generally if more than 4) where each vector element is
1532   //     individually present in Ins.
1533   // So a different index should be used for indexing into Ins.
1534   // See similar issue in LowerCall.
1535   unsigned InsIdx = 0;
1536
1537   int idx = 0;
1538   for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
1539     Type *Ty = argTypes[i];
1540
1541     // If the kernel argument is image*_t or sampler_t, convert it to
1542     // a i32 constant holding the parameter position. This can later
1543     // matched in the AsmPrinter to output the correct mangled name.
1544     if (isImageOrSamplerVal(
1545             theArgs[i],
1546             (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1547                                      : 0))) {
1548       assert(isKernel && "Only kernels can have image/sampler params");
1549       InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
1550       continue;
1551     }
1552
1553     if (theArgs[i]->use_empty()) {
1554       // argument is dead
1555       if (Ty->isAggregateType()) {
1556         SmallVector<EVT, 16> vtparts;
1557
1558         ComputePTXValueVTs(*this, Ty, vtparts);
1559         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1560         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1561              ++parti) {
1562           EVT partVT = vtparts[parti];
1563           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
1564           ++InsIdx;
1565         }
1566         if (vtparts.size() > 0)
1567           --InsIdx;
1568         continue;
1569       }
1570       if (Ty->isVectorTy()) {
1571         EVT ObjectVT = getValueType(Ty);
1572         unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
1573         for (unsigned parti = 0; parti < NumRegs; ++parti) {
1574           InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1575           ++InsIdx;
1576         }
1577         if (NumRegs > 0)
1578           --InsIdx;
1579         continue;
1580       }
1581       InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
1582       continue;
1583     }
1584
1585     // In the following cases, assign a node order of "idx+1"
1586     // to newly created nodes. The SDNodes for params have to
1587     // appear in the same order as their order of appearance
1588     // in the original function. "idx+1" holds that order.
1589     if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
1590       if (Ty->isAggregateType()) {
1591         SmallVector<EVT, 16> vtparts;
1592         SmallVector<uint64_t, 16> offsets;
1593
1594         // NOTE: Here, we lose the ability to issue vector loads for vectors
1595         // that are a part of a struct.  This should be investigated in the
1596         // future.
1597         ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
1598         assert(vtparts.size() > 0 && "empty aggregate type not expected");
1599         bool aggregateIsPacked = false;
1600         if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
1601           aggregateIsPacked = STy->isPacked();
1602
1603         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1604         for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
1605              ++parti) {
1606           EVT partVT = vtparts[parti];
1607           Value *srcValue = Constant::getNullValue(
1608               PointerType::get(partVT.getTypeForEVT(F->getContext()),
1609                                llvm::ADDRESS_SPACE_PARAM));
1610           SDValue srcAddr =
1611               DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1612                           DAG.getConstant(offsets[parti], getPointerTy()));
1613           unsigned partAlign =
1614               aggregateIsPacked ? 1
1615                                 : TD->getABITypeAlignment(
1616                                       partVT.getTypeForEVT(F->getContext()));
1617           SDValue p;
1618           if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
1619             ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1620                                      ISD::SEXTLOAD : ISD::ZEXTLOAD;
1621             p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
1622                                MachinePointerInfo(srcValue), partVT, false,
1623                                false, partAlign);
1624           } else {
1625             p = DAG.getLoad(partVT, dl, Root, srcAddr,
1626                             MachinePointerInfo(srcValue), false, false, false,
1627                             partAlign);
1628           }
1629           if (p.getNode())
1630             p.getNode()->setIROrder(idx + 1);
1631           InVals.push_back(p);
1632           ++InsIdx;
1633         }
1634         if (vtparts.size() > 0)
1635           --InsIdx;
1636         continue;
1637       }
1638       if (Ty->isVectorTy()) {
1639         EVT ObjectVT = getValueType(Ty);
1640         SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1641         unsigned NumElts = ObjectVT.getVectorNumElements();
1642         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
1643                "Vector was not scalarized");
1644         unsigned Ofst = 0;
1645         EVT EltVT = ObjectVT.getVectorElementType();
1646
1647         // V1 load
1648         // f32 = load ...
1649         if (NumElts == 1) {
1650           // We only have one element, so just directly load it
1651           Value *SrcValue = Constant::getNullValue(PointerType::get(
1652               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1653           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1654                                         DAG.getConstant(Ofst, getPointerTy()));
1655           SDValue P = DAG.getLoad(
1656               EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1657               false, true,
1658               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1659           if (P.getNode())
1660             P.getNode()->setIROrder(idx + 1);
1661
1662           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1663             P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
1664           InVals.push_back(P);
1665           Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
1666           ++InsIdx;
1667         } else if (NumElts == 2) {
1668           // V2 load
1669           // f32,f32 = load ...
1670           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
1671           Value *SrcValue = Constant::getNullValue(PointerType::get(
1672               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1673           SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1674                                         DAG.getConstant(Ofst, getPointerTy()));
1675           SDValue P = DAG.getLoad(
1676               VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1677               false, true,
1678               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1679           if (P.getNode())
1680             P.getNode()->setIROrder(idx + 1);
1681
1682           SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1683                                      DAG.getIntPtrConstant(0));
1684           SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1685                                      DAG.getIntPtrConstant(1));
1686
1687           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
1688             Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
1689             Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
1690           }
1691
1692           InVals.push_back(Elt0);
1693           InVals.push_back(Elt1);
1694           Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1695           InsIdx += 2;
1696         } else {
1697           // V4 loads
1698           // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
1699           // the
1700           // vector will be expanded to a power of 2 elements, so we know we can
1701           // always round up to the next multiple of 4 when creating the vector
1702           // loads.
1703           // e.g.  4 elem => 1 ld.v4
1704           //       6 elem => 2 ld.v4
1705           //       8 elem => 2 ld.v4
1706           //      11 elem => 3 ld.v4
1707           unsigned VecSize = 4;
1708           if (EltVT.getSizeInBits() == 64) {
1709             VecSize = 2;
1710           }
1711           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
1712           for (unsigned i = 0; i < NumElts; i += VecSize) {
1713             Value *SrcValue = Constant::getNullValue(
1714                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
1715                                  llvm::ADDRESS_SPACE_PARAM));
1716             SDValue SrcAddr =
1717                 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
1718                             DAG.getConstant(Ofst, getPointerTy()));
1719             SDValue P = DAG.getLoad(
1720                 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
1721                 false, true,
1722                 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
1723             if (P.getNode())
1724               P.getNode()->setIROrder(idx + 1);
1725
1726             for (unsigned j = 0; j < VecSize; ++j) {
1727               if (i + j >= NumElts)
1728                 break;
1729               SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
1730                                         DAG.getIntPtrConstant(j));
1731               if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
1732                 Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
1733               InVals.push_back(Elt);
1734             }
1735             Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1736             InsIdx += VecSize;
1737           }
1738         }
1739
1740         if (NumElts > 0)
1741           --InsIdx;
1742         continue;
1743       }
1744       // A plain scalar.
1745       EVT ObjectVT = getValueType(Ty);
1746       // If ABI, load from the param symbol
1747       SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1748       Value *srcValue = Constant::getNullValue(PointerType::get(
1749           ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1750       SDValue p;
1751        if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
1752         ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1753                                        ISD::SEXTLOAD : ISD::ZEXTLOAD;
1754         p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
1755                            MachinePointerInfo(srcValue), ObjectVT, false, false,
1756         TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1757       } else {
1758         p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
1759                         MachinePointerInfo(srcValue), false, false, false,
1760         TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1761       }
1762       if (p.getNode())
1763         p.getNode()->setIROrder(idx + 1);
1764       InVals.push_back(p);
1765       continue;
1766     }
1767
1768     // Param has ByVal attribute
1769     // Return MoveParam(param symbol).
1770     // Ideally, the param symbol can be returned directly,
1771     // but when SDNode builder decides to use it in a CopyToReg(),
1772     // machine instruction fails because TargetExternalSymbol
1773     // (not lowered) is target dependent, and CopyToReg assumes
1774     // the source is lowered.
1775     EVT ObjectVT = getValueType(Ty);
1776     assert(ObjectVT == Ins[InsIdx].VT &&
1777            "Ins type did not match function type");
1778     SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1779     SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1780     if (p.getNode())
1781       p.getNode()->setIROrder(idx + 1);
1782     if (isKernel)
1783       InVals.push_back(p);
1784     else {
1785       SDValue p2 = DAG.getNode(
1786           ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1787           DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1788       InVals.push_back(p2);
1789     }
1790   }
1791
1792   // Clang will check explicit VarArg and issue error if any. However, Clang
1793   // will let code with
1794   // implicit var arg like f() pass. See bug 617733.
1795   // We treat this case as if the arg list is empty.
1796   // if (F.isVarArg()) {
1797   // assert(0 && "VarArg not supported yet!");
1798   //}
1799
1800   if (!OutChains.empty())
1801     DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
1802                             OutChains.size()));
1803
1804   return Chain;
1805 }
1806
1807
1808 SDValue
1809 NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
1810                                  bool isVarArg,
1811                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
1812                                  const SmallVectorImpl<SDValue> &OutVals,
1813                                  SDLoc dl, SelectionDAG &DAG) const {
1814   MachineFunction &MF = DAG.getMachineFunction();
1815   const Function *F = MF.getFunction();
1816   Type *RetTy = F->getReturnType();
1817   const DataLayout *TD = getDataLayout();
1818
1819   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1820   assert(isABI && "Non-ABI compilation is not supported");
1821   if (!isABI)
1822     return Chain;
1823
1824   if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
1825     // If we have a vector type, the OutVals array will be the scalarized
1826     // components and we have combine them into 1 or more vector stores.
1827     unsigned NumElts = VTy->getNumElements();
1828     assert(NumElts == Outs.size() && "Bad scalarization of return value");
1829
1830     // const_cast can be removed in later LLVM versions
1831     EVT EltVT = getValueType(RetTy).getVectorElementType();
1832     bool NeedExtend = false;
1833     if (EltVT.getSizeInBits() < 16)
1834       NeedExtend = true;
1835
1836     // V1 store
1837     if (NumElts == 1) {
1838       SDValue StoreVal = OutVals[0];
1839       // We only have one element, so just directly store it
1840       if (NeedExtend)
1841         StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
1842       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
1843       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1844                                       DAG.getVTList(MVT::Other), &Ops[0], 3,
1845                                       EltVT, MachinePointerInfo());
1846
1847     } else if (NumElts == 2) {
1848       // V2 store
1849       SDValue StoreVal0 = OutVals[0];
1850       SDValue StoreVal1 = OutVals[1];
1851
1852       if (NeedExtend) {
1853         StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
1854         StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
1855       }
1856
1857       SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
1858                         StoreVal1 };
1859       Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
1860                                       DAG.getVTList(MVT::Other), &Ops[0], 4,
1861                                       EltVT, MachinePointerInfo());
1862     } else {
1863       // V4 stores
1864       // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
1865       // vector will be expanded to a power of 2 elements, so we know we can
1866       // always round up to the next multiple of 4 when creating the vector
1867       // stores.
1868       // e.g.  4 elem => 1 st.v4
1869       //       6 elem => 2 st.v4
1870       //       8 elem => 2 st.v4
1871       //      11 elem => 3 st.v4
1872
1873       unsigned VecSize = 4;
1874       if (OutVals[0].getValueType().getSizeInBits() == 64)
1875         VecSize = 2;
1876
1877       unsigned Offset = 0;
1878
1879       EVT VecVT =
1880           EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
1881       unsigned PerStoreOffset =
1882           TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
1883
1884       for (unsigned i = 0; i < NumElts; i += VecSize) {
1885         // Get values
1886         SDValue StoreVal;
1887         SmallVector<SDValue, 8> Ops;
1888         Ops.push_back(Chain);
1889         Ops.push_back(DAG.getConstant(Offset, MVT::i32));
1890         unsigned Opc = NVPTXISD::StoreRetvalV2;
1891         EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
1892
1893         StoreVal = OutVals[i];
1894         if (NeedExtend)
1895           StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1896         Ops.push_back(StoreVal);
1897
1898         if (i + 1 < NumElts) {
1899           StoreVal = OutVals[i + 1];
1900           if (NeedExtend)
1901             StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1902         } else {
1903           StoreVal = DAG.getUNDEF(ExtendedVT);
1904         }
1905         Ops.push_back(StoreVal);
1906
1907         if (VecSize == 4) {
1908           Opc = NVPTXISD::StoreRetvalV4;
1909           if (i + 2 < NumElts) {
1910             StoreVal = OutVals[i + 2];
1911             if (NeedExtend)
1912               StoreVal =
1913                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1914           } else {
1915             StoreVal = DAG.getUNDEF(ExtendedVT);
1916           }
1917           Ops.push_back(StoreVal);
1918
1919           if (i + 3 < NumElts) {
1920             StoreVal = OutVals[i + 3];
1921             if (NeedExtend)
1922               StoreVal =
1923                   DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
1924           } else {
1925             StoreVal = DAG.getUNDEF(ExtendedVT);
1926           }
1927           Ops.push_back(StoreVal);
1928         }
1929
1930         // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
1931         Chain =
1932             DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
1933                                     Ops.size(), EltVT, MachinePointerInfo());
1934         Offset += PerStoreOffset;
1935       }
1936     }
1937   } else {
1938     SmallVector<EVT, 16> ValVTs;
1939     // const_cast is necessary since we are still using an LLVM version from
1940     // before the type system re-write.
1941     ComputePTXValueVTs(*this, RetTy, ValVTs);
1942     assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
1943
1944     unsigned SizeSoFar = 0;
1945     for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1946       SDValue theVal = OutVals[i];
1947       EVT TheValType = theVal.getValueType();
1948       unsigned numElems = 1;
1949       if (TheValType.isVector())
1950         numElems = TheValType.getVectorNumElements();
1951       for (unsigned j = 0, je = numElems; j != je; ++j) {
1952         SDValue TmpVal = theVal;
1953         if (TheValType.isVector())
1954           TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1955                                TheValType.getVectorElementType(), TmpVal,
1956                                DAG.getIntPtrConstant(j));
1957         EVT TheStoreType = ValVTs[i];
1958         if (RetTy->isIntegerTy() &&
1959             TD->getTypeAllocSizeInBits(RetTy) < 32) {
1960           // The following zero-extension is for integer types only, and
1961           // specifically not for aggregates.
1962           TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
1963           TheStoreType = MVT::i32;
1964         }
1965         else if (TmpVal.getValueType().getSizeInBits() < 16)
1966           TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);
1967
1968         SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal };
1969         Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
1970                                         DAG.getVTList(MVT::Other), &Ops[0],
1971                                         3, TheStoreType,
1972                                         MachinePointerInfo());
1973         if(TheValType.isVector())
1974           SizeSoFar += 
1975             TheStoreType.getVectorElementType().getStoreSizeInBits() / 8;
1976         else
1977           SizeSoFar += TheStoreType.getStoreSizeInBits()/8;
1978       }
1979     }
1980   }
1981
1982   return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
1983 }
1984
1985
1986 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1987     SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1988     SelectionDAG &DAG) const {
1989   if (Constraint.length() > 1)
1990     return;
1991   else
1992     TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1993 }
1994
1995 // NVPTX suuport vector of legal types of any length in Intrinsics because the
1996 // NVPTX specific type legalizer
1997 // will legalize them to the PTX supported length.
1998 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1999   if (isTypeLegal(VT))
2000     return true;
2001   if (VT.isVector()) {
2002     MVT eVT = VT.getVectorElementType();
2003     if (isTypeLegal(eVT))
2004       return true;
2005   }
2006   return false;
2007 }
2008
2009 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
2010 // TgtMemIntrinsic
2011 // because we need the information that is only available in the "Value" type
2012 // of destination
2013 // pointer. In particular, the address space information.
2014 bool NVPTXTargetLowering::getTgtMemIntrinsic(
2015     IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
2016   switch (Intrinsic) {
2017   default:
2018     return false;
2019
2020   case Intrinsic::nvvm_atomic_load_add_f32:
2021     Info.opc = ISD::INTRINSIC_W_CHAIN;
2022     Info.memVT = MVT::f32;
2023     Info.ptrVal = I.getArgOperand(0);
2024     Info.offset = 0;
2025     Info.vol = 0;
2026     Info.readMem = true;
2027     Info.writeMem = true;
2028     Info.align = 0;
2029     return true;
2030
2031   case Intrinsic::nvvm_atomic_load_inc_32:
2032   case Intrinsic::nvvm_atomic_load_dec_32:
2033     Info.opc = ISD::INTRINSIC_W_CHAIN;
2034     Info.memVT = MVT::i32;
2035     Info.ptrVal = I.getArgOperand(0);
2036     Info.offset = 0;
2037     Info.vol = 0;
2038     Info.readMem = true;
2039     Info.writeMem = true;
2040     Info.align = 0;
2041     return true;
2042
2043   case Intrinsic::nvvm_ldu_global_i:
2044   case Intrinsic::nvvm_ldu_global_f:
2045   case Intrinsic::nvvm_ldu_global_p:
2046
2047     Info.opc = ISD::INTRINSIC_W_CHAIN;
2048     if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
2049       Info.memVT = getValueType(I.getType());
2050     else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
2051       Info.memVT = getValueType(I.getType());
2052     else
2053       Info.memVT = MVT::f32;
2054     Info.ptrVal = I.getArgOperand(0);
2055     Info.offset = 0;
2056     Info.vol = 0;
2057     Info.readMem = true;
2058     Info.writeMem = false;
2059     Info.align = 0;
2060     return true;
2061
2062   }
2063   return false;
2064 }
2065
2066 /// isLegalAddressingMode - Return true if the addressing mode represented
2067 /// by AM is legal for this target, for a load/store of the specified type.
2068 /// Used to guide target specific optimizations, like loop strength reduction
2069 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
2070 /// (CodeGenPrepare.cpp)
2071 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
2072                                                 Type *Ty) const {
2073
2074   // AddrMode - This represents an addressing mode of:
2075   //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
2076   //
2077   // The legal address modes are
2078   // - [avar]
2079   // - [areg]
2080   // - [areg+immoff]
2081   // - [immAddr]
2082
2083   if (AM.BaseGV) {
2084     if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
2085       return false;
2086     return true;
2087   }
2088
2089   switch (AM.Scale) {
2090   case 0: // "r", "r+i" or "i" is allowed
2091     break;
2092   case 1:
2093     if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
2094       return false;
2095     // Otherwise we have r+i.
2096     break;
2097   default:
2098     // No scale > 1 is allowed
2099     return false;
2100   }
2101   return true;
2102 }
2103
2104 //===----------------------------------------------------------------------===//
2105 //                         NVPTX Inline Assembly Support
2106 //===----------------------------------------------------------------------===//
2107
2108 /// getConstraintType - Given a constraint letter, return the type of
2109 /// constraint it is for this target.
2110 NVPTXTargetLowering::ConstraintType
2111 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
2112   if (Constraint.size() == 1) {
2113     switch (Constraint[0]) {
2114     default:
2115       break;
2116     case 'r':
2117     case 'h':
2118     case 'c':
2119     case 'l':
2120     case 'f':
2121     case 'd':
2122     case '0':
2123     case 'N':
2124       return C_RegisterClass;
2125     }
2126   }
2127   return TargetLowering::getConstraintType(Constraint);
2128 }
2129
2130 std::pair<unsigned, const TargetRegisterClass *>
2131 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
2132                                                   MVT VT) const {
2133   if (Constraint.size() == 1) {
2134     switch (Constraint[0]) {
2135     case 'c':
2136       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2137     case 'h':
2138       return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
2139     case 'r':
2140       return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
2141     case 'l':
2142     case 'N':
2143       return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
2144     case 'f':
2145       return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
2146     case 'd':
2147       return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
2148     }
2149   }
2150   return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
2151 }
2152
2153 /// getFunctionAlignment - Return the Log2 alignment of this function.
2154 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
2155   return 4;
2156 }
2157
2158 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
2159 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
2160                               SmallVectorImpl<SDValue> &Results) {
2161   EVT ResVT = N->getValueType(0);
2162   SDLoc DL(N);
2163
2164   assert(ResVT.isVector() && "Vector load must have vector type");
2165
2166   // We only handle "native" vector sizes for now, e.g. <4 x double> is not
2167   // legal.  We can (and should) split that into 2 loads of <2 x double> here
2168   // but I'm leaving that as a TODO for now.
2169   assert(ResVT.isSimple() && "Can only handle simple types");
2170   switch (ResVT.getSimpleVT().SimpleTy) {
2171   default:
2172     return;
2173   case MVT::v2i8:
2174   case MVT::v2i16:
2175   case MVT::v2i32:
2176   case MVT::v2i64:
2177   case MVT::v2f32:
2178   case MVT::v2f64:
2179   case MVT::v4i8:
2180   case MVT::v4i16:
2181   case MVT::v4i32:
2182   case MVT::v4f32:
2183     // This is a "native" vector type
2184     break;
2185   }
2186
2187   EVT EltVT = ResVT.getVectorElementType();
2188   unsigned NumElts = ResVT.getVectorNumElements();
2189
2190   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
2191   // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2192   // loaded type to i16 and propogate the "real" type as the memory type.
2193   bool NeedTrunc = false;
2194   if (EltVT.getSizeInBits() < 16) {
2195     EltVT = MVT::i16;
2196     NeedTrunc = true;
2197   }
2198
2199   unsigned Opcode = 0;
2200   SDVTList LdResVTs;
2201
2202   switch (NumElts) {
2203   default:
2204     return;
2205   case 2:
2206     Opcode = NVPTXISD::LoadV2;
2207     LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2208     break;
2209   case 4: {
2210     Opcode = NVPTXISD::LoadV4;
2211     EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2212     LdResVTs = DAG.getVTList(ListVTs, 5);
2213     break;
2214   }
2215   }
2216
2217   SmallVector<SDValue, 8> OtherOps;
2218
2219   // Copy regular operands
2220   for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2221     OtherOps.push_back(N->getOperand(i));
2222
2223   LoadSDNode *LD = cast<LoadSDNode>(N);
2224
2225   // The select routine does not have access to the LoadSDNode instance, so
2226   // pass along the extension information
2227   OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
2228
2229   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
2230                                           OtherOps.size(), LD->getMemoryVT(),
2231                                           LD->getMemOperand());
2232
2233   SmallVector<SDValue, 4> ScalarRes;
2234
2235   for (unsigned i = 0; i < NumElts; ++i) {
2236     SDValue Res = NewLD.getValue(i);
2237     if (NeedTrunc)
2238       Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2239     ScalarRes.push_back(Res);
2240   }
2241
2242   SDValue LoadChain = NewLD.getValue(NumElts);
2243
2244   SDValue BuildVec =
2245       DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2246
2247   Results.push_back(BuildVec);
2248   Results.push_back(LoadChain);
2249 }
2250
2251 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
2252                                      SmallVectorImpl<SDValue> &Results) {
2253   SDValue Chain = N->getOperand(0);
2254   SDValue Intrin = N->getOperand(1);
2255   SDLoc DL(N);
2256
2257   // Get the intrinsic ID
2258   unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
2259   switch (IntrinNo) {
2260   default:
2261     return;
2262   case Intrinsic::nvvm_ldg_global_i:
2263   case Intrinsic::nvvm_ldg_global_f:
2264   case Intrinsic::nvvm_ldg_global_p:
2265   case Intrinsic::nvvm_ldu_global_i:
2266   case Intrinsic::nvvm_ldu_global_f:
2267   case Intrinsic::nvvm_ldu_global_p: {
2268     EVT ResVT = N->getValueType(0);
2269
2270     if (ResVT.isVector()) {
2271       // Vector LDG/LDU
2272
2273       unsigned NumElts = ResVT.getVectorNumElements();
2274       EVT EltVT = ResVT.getVectorElementType();
2275
2276       // Since LDU/LDG are target nodes, we cannot rely on DAG type
2277       // legalization.
2278       // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
2279       // loaded type to i16 and propogate the "real" type as the memory type.
2280       bool NeedTrunc = false;
2281       if (EltVT.getSizeInBits() < 16) {
2282         EltVT = MVT::i16;
2283         NeedTrunc = true;
2284       }
2285
2286       unsigned Opcode = 0;
2287       SDVTList LdResVTs;
2288
2289       switch (NumElts) {
2290       default:
2291         return;
2292       case 2:
2293         switch (IntrinNo) {
2294         default:
2295           return;
2296         case Intrinsic::nvvm_ldg_global_i:
2297         case Intrinsic::nvvm_ldg_global_f:
2298         case Intrinsic::nvvm_ldg_global_p:
2299           Opcode = NVPTXISD::LDGV2;
2300           break;
2301         case Intrinsic::nvvm_ldu_global_i:
2302         case Intrinsic::nvvm_ldu_global_f:
2303         case Intrinsic::nvvm_ldu_global_p:
2304           Opcode = NVPTXISD::LDUV2;
2305           break;
2306         }
2307         LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
2308         break;
2309       case 4: {
2310         switch (IntrinNo) {
2311         default:
2312           return;
2313         case Intrinsic::nvvm_ldg_global_i:
2314         case Intrinsic::nvvm_ldg_global_f:
2315         case Intrinsic::nvvm_ldg_global_p:
2316           Opcode = NVPTXISD::LDGV4;
2317           break;
2318         case Intrinsic::nvvm_ldu_global_i:
2319         case Intrinsic::nvvm_ldu_global_f:
2320         case Intrinsic::nvvm_ldu_global_p:
2321           Opcode = NVPTXISD::LDUV4;
2322           break;
2323         }
2324         EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
2325         LdResVTs = DAG.getVTList(ListVTs, 5);
2326         break;
2327       }
2328       }
2329
2330       SmallVector<SDValue, 8> OtherOps;
2331
2332       // Copy regular operands
2333
2334       OtherOps.push_back(Chain); // Chain
2335                                  // Skip operand 1 (intrinsic ID)
2336       // Others
2337       for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
2338         OtherOps.push_back(N->getOperand(i));
2339
2340       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2341
2342       SDValue NewLD = DAG.getMemIntrinsicNode(
2343           Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
2344           MemSD->getMemoryVT(), MemSD->getMemOperand());
2345
2346       SmallVector<SDValue, 4> ScalarRes;
2347
2348       for (unsigned i = 0; i < NumElts; ++i) {
2349         SDValue Res = NewLD.getValue(i);
2350         if (NeedTrunc)
2351           Res =
2352               DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
2353         ScalarRes.push_back(Res);
2354       }
2355
2356       SDValue LoadChain = NewLD.getValue(NumElts);
2357
2358       SDValue BuildVec =
2359           DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
2360
2361       Results.push_back(BuildVec);
2362       Results.push_back(LoadChain);
2363     } else {
2364       // i8 LDG/LDU
2365       assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
2366              "Custom handling of non-i8 ldu/ldg?");
2367
2368       // Just copy all operands as-is
2369       SmallVector<SDValue, 4> Ops;
2370       for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
2371         Ops.push_back(N->getOperand(i));
2372
2373       // Force output to i16
2374       SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
2375
2376       MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
2377
2378       // We make sure the memory type is i8, which will be used during isel
2379       // to select the proper instruction.
2380       SDValue NewLD =
2381           DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
2382                                   Ops.size(), MVT::i8, MemSD->getMemOperand());
2383
2384       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
2385                                     NewLD.getValue(0)));
2386       Results.push_back(NewLD.getValue(1));
2387     }
2388   }
2389   }
2390 }
2391
2392 void NVPTXTargetLowering::ReplaceNodeResults(
2393     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
2394   switch (N->getOpcode()) {
2395   default:
2396     report_fatal_error("Unhandled custom legalization");
2397   case ISD::LOAD:
2398     ReplaceLoadVector(N, DAG, Results);
2399     return;
2400   case ISD::INTRINSIC_W_CHAIN:
2401     ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
2402     return;
2403   }
2404 }