2 // The LLVM Compiler Infrastructure
4 // This file is distributed under the University of Illinois Open Source
5 // License. See LICENSE.TXT for details.
7 //===----------------------------------------------------------------------===//
9 // This file defines the interfaces that NVPTX uses to lower LLVM code into a
9 // selection DAG.
12 //===----------------------------------------------------------------------===//
14 #include "NVPTXISelLowering.h"
16 #include "NVPTXTargetMachine.h"
17 #include "NVPTXTargetObjectFile.h"
18 #include "NVPTXUtilities.h"
19 #include "llvm/CodeGen/Analysis.h"
20 #include "llvm/CodeGen/MachineFrameInfo.h"
21 #include "llvm/CodeGen/MachineFunction.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 #include "llvm/CodeGen/MachineRegisterInfo.h"
24 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/Intrinsics.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/MC/MCSectionELF.h"
32 #include "llvm/Support/CallSite.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/raw_ostream.h"
40 #define DEBUG_TYPE "nvptx-lower"
// Monotonically increasing id used to give each lowered call site a
// unique suffix (e.g. "prototype_<N>" in getPrototype / LowerCall).
44 static unsigned int uniqueCallSite = 0;
// Command-line flag selecting register-pressure scheduling for NVPTX.
// NOTE(review): the option-name string argument of this cl::opt is
// missing from this copy of the file; also "pressue" in the visible
// description string is a typo for "pressure" (runtime string — left
// unchanged here, fix upstream).
46 static cl::opt<bool> sched4reg(
48 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false));
// Returns whether VT is one of the vector types this backend handles
// natively via custom load/store lowering (see the registration loop in
// the NVPTXTargetLowering constructor).
// NOTE(review): the switch cases, return statements and closing braces
// of this function are missing from this copy — restore from upstream.
50 static bool IsPTXVectorType(MVT VT) {
51 switch (VT.SimpleTy) {
68 // NVPTXTargetLowering Constructor.
// Configures the NVPTX lowering: register classes per value type,
// legality/expansion of operations PTX lacks, and custom hooks for
// loads/stores (i1 and native vectors) and global addresses.
// NOTE(review): this copy of the file is missing interleaved lines
// (conditional guards, braces), and has stray line numbers fused into
// the text; restore from upstream before building.
69 NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
70 : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
71 nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
73 // always lower memset, memcpy, and memmove intrinsics to load/store
74 // instructions, rather
75 // than generating calls to memset, memcpy or memmove.
76 MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
77 MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
78 MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;
80 setBooleanContents(ZeroOrNegativeOneBooleanContent);
82 // Jump is Expensive. Don't create extra control flow for 'and', 'or'
83 // condition branches.
84 setJumpIsExpensive(true);
86 // By default, use the Source scheduling
// NOTE(review): upstream guards the RegPressure preference behind the
// sched4reg flag; the if/else lines around the next two calls are
// missing in this copy.
88 setSchedulingPreference(Sched::RegPressure);
90 setSchedulingPreference(Sched::Source);
// One register class per scalar value type PTX supports.
92 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
93 addRegisterClass(MVT::i8, &NVPTX::Int8RegsRegClass);
94 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
95 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
96 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
97 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
98 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
100 // Operations not directly supported by NVPTX.
101 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
102 setOperationAction(ISD::BR_CC, MVT::f32, Expand);
103 setOperationAction(ISD::BR_CC, MVT::f64, Expand);
104 setOperationAction(ISD::BR_CC, MVT::i1, Expand);
105 setOperationAction(ISD::BR_CC, MVT::i8, Expand);
106 setOperationAction(ISD::BR_CC, MVT::i16, Expand);
107 setOperationAction(ISD::BR_CC, MVT::i32, Expand);
108 setOperationAction(ISD::BR_CC, MVT::i64, Expand);
109 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Expand);
110 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Expand);
111 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Expand);
112 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Expand);
113 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);
// Rotates are legal only when the subtarget provides them; otherwise
// expanded. NOTE(review): the "} else {" lines between the Legal and
// Expand pairs are missing in this copy.
115 if (nvptxSubtarget.hasROT64()) {
116 setOperationAction(ISD::ROTL, MVT::i64, Legal);
117 setOperationAction(ISD::ROTR, MVT::i64, Legal);
119 setOperationAction(ISD::ROTL, MVT::i64, Expand);
120 setOperationAction(ISD::ROTR, MVT::i64, Expand);
122 if (nvptxSubtarget.hasROT32()) {
123 setOperationAction(ISD::ROTL, MVT::i32, Legal);
124 setOperationAction(ISD::ROTR, MVT::i32, Legal);
126 setOperationAction(ISD::ROTL, MVT::i32, Expand);
127 setOperationAction(ISD::ROTR, MVT::i32, Expand);
130 setOperationAction(ISD::ROTL, MVT::i16, Expand);
131 setOperationAction(ISD::ROTR, MVT::i16, Expand);
132 setOperationAction(ISD::ROTL, MVT::i8, Expand);
133 setOperationAction(ISD::ROTR, MVT::i8, Expand);
134 setOperationAction(ISD::BSWAP, MVT::i16, Expand);
135 setOperationAction(ISD::BSWAP, MVT::i32, Expand);
136 setOperationAction(ISD::BSWAP, MVT::i64, Expand);
138 // Indirect branch is not supported.
139 // This also disables Jump Table creation.
140 setOperationAction(ISD::BR_JT, MVT::Other, Expand);
141 setOperationAction(ISD::BRIND, MVT::Other, Expand);
143 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
144 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
146 // We want to legalize constant related memmove and memcopy
148 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
150 // Turn FP extload into load/fextend
151 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
152 // Turn FP truncstore into trunc + store.
153 setTruncStoreAction(MVT::f64, MVT::f32, Expand);
155 // PTX does not support load / store predicate registers
156 setOperationAction(ISD::LOAD, MVT::i1, Custom);
157 setOperationAction(ISD::STORE, MVT::i1, Custom);
159 setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
160 setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
161 setTruncStoreAction(MVT::i64, MVT::i1, Expand);
162 setTruncStoreAction(MVT::i32, MVT::i1, Expand);
163 setTruncStoreAction(MVT::i16, MVT::i1, Expand);
164 setTruncStoreAction(MVT::i8, MVT::i1, Expand);
166 // This is legal in NVPTX
167 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
168 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
170 // TRAP can be lowered to PTX trap
171 setOperationAction(ISD::TRAP, MVT::Other, Legal);
173 // Register custom handling for vector loads/stores
174 for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
176 MVT VT = (MVT::SimpleValueType) i;
177 if (IsPTXVectorType(VT)) {
178 setOperationAction(ISD::LOAD, VT, Custom);
179 setOperationAction(ISD::STORE, VT, Custom);
180 setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
184 // Now deduce the information based on the above mentioned
186 computeRegisterProperties();
// Returns the printable name of an NVPTX-specific SelectionDAG node
// opcode; used by DAG debug dumps.
// NOTE(review): the "switch (Opcode)" header, the default case, and the
// closing braces are missing from this copy — restore from upstream.
189 const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
194 return "NVPTXISD::CALL";
195 case NVPTXISD::RET_FLAG:
196 return "NVPTXISD::RET_FLAG";
197 case NVPTXISD::Wrapper:
198 return "NVPTXISD::Wrapper";
199 case NVPTXISD::NVBuiltin:
200 return "NVPTXISD::NVBuiltin";
201 case NVPTXISD::DeclareParam:
202 return "NVPTXISD::DeclareParam";
203 case NVPTXISD::DeclareScalarParam:
204 return "NVPTXISD::DeclareScalarParam";
205 case NVPTXISD::DeclareRet:
206 return "NVPTXISD::DeclareRet";
207 case NVPTXISD::DeclareRetParam:
208 return "NVPTXISD::DeclareRetParam";
209 case NVPTXISD::PrintCall:
210 return "NVPTXISD::PrintCall";
211 case NVPTXISD::LoadParam:
212 return "NVPTXISD::LoadParam";
213 case NVPTXISD::StoreParam:
214 return "NVPTXISD::StoreParam";
215 case NVPTXISD::StoreParamS32:
216 return "NVPTXISD::StoreParamS32";
217 case NVPTXISD::StoreParamU32:
218 return "NVPTXISD::StoreParamU32";
219 case NVPTXISD::MoveToParam:
220 return "NVPTXISD::MoveToParam";
221 case NVPTXISD::CallArgBegin:
222 return "NVPTXISD::CallArgBegin";
223 case NVPTXISD::CallArg:
224 return "NVPTXISD::CallArg";
225 case NVPTXISD::LastCallArg:
226 return "NVPTXISD::LastCallArg";
227 case NVPTXISD::CallArgEnd:
228 return "NVPTXISD::CallArgEnd";
229 case NVPTXISD::CallVoid:
230 return "NVPTXISD::CallVoid";
231 case NVPTXISD::CallVal:
232 return "NVPTXISD::CallVal";
233 case NVPTXISD::CallSymbol:
234 return "NVPTXISD::CallSymbol";
235 case NVPTXISD::Prototype:
236 return "NVPTXISD::Prototype";
237 case NVPTXISD::MoveParam:
238 return "NVPTXISD::MoveParam";
239 case NVPTXISD::MoveRetval:
240 return "NVPTXISD::MoveRetval";
241 case NVPTXISD::MoveToRetval:
242 return "NVPTXISD::MoveToRetval";
243 case NVPTXISD::StoreRetval:
244 return "NVPTXISD::StoreRetval";
245 case NVPTXISD::PseudoUseParam:
246 return "NVPTXISD::PseudoUseParam";
247 case NVPTXISD::RETURN:
248 return "NVPTXISD::RETURN";
249 case NVPTXISD::CallSeqBegin:
250 return "NVPTXISD::CallSeqBegin";
251 case NVPTXISD::CallSeqEnd:
252 return "NVPTXISD::CallSeqEnd";
253 case NVPTXISD::LoadV2:
254 return "NVPTXISD::LoadV2";
255 case NVPTXISD::LoadV4:
256 return "NVPTXISD::LoadV4";
257 case NVPTXISD::LDGV2:
258 return "NVPTXISD::LDGV2";
259 case NVPTXISD::LDGV4:
260 return "NVPTXISD::LDGV4";
261 case NVPTXISD::LDUV2:
262 return "NVPTXISD::LDUV2";
263 case NVPTXISD::LDUV4:
264 return "NVPTXISD::LDUV4";
265 case NVPTXISD::StoreV2:
266 return "NVPTXISD::StoreV2";
267 case NVPTXISD::StoreV4:
268 return "NVPTXISD::StoreV4";
// Vectors of i1 are split into individual scalar elements during
// legalization (PTX predicate registers cannot be loaded/stored as a
// vector — see the i1 Custom handling set up in the constructor).
// NOTE(review): the closing brace is missing from this copy.
272 bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
273 return VT == MVT::i1;
// Lowers ISD::GlobalAddress by wrapping the target global address in an
// NVPTXISD::Wrapper node at pointer width.
// NOTE(review): the preceding return-type line ("SDValue") and the
// closing brace are missing from this copy.
277 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
278 DebugLoc dl = Op.getDebugLoc();
279 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
280 Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
281 return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
// Builds the PTX ".callprototype" declaration string for a call site:
// the return declaration (".param .b<size> _" for scalars/pointers under
// the ABI, ".param .align <A> .b8 _[<size>]" for aggregates, ".reg"
// forms for pre-sm_20 non-ABI) followed by one declaration per argument
// from Args/Outs. Used by LowerCall for indirect calls.
// NOTE(review): many interleaved lines (the raw_string_ostream setup,
// else branches, accumulations, closing braces and the final return)
// are missing from this copy — restore from upstream before building.
284 std::string NVPTXTargetLowering::getPrototype(
285 Type *retTy, const ArgListTy &Args,
286 const SmallVectorImpl<ISD::OutputArg> &Outs, unsigned retAlignment) const {
// ABI (.param-based) calling convention is used from sm_20 onwards.
288 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
291 O << "prototype_" << uniqueCallSite << " : .callprototype ";
293 if (retTy->getTypeID() == Type::VoidTyID)
// Scalar (integer/FP) returns: a single .param of the value's width.
298 if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
300 if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
301 size = ITy->getBitWidth();
305 assert(retTy->isFloatingPointTy() &&
306 "Floating point type expected here");
307 size = retTy->getPrimitiveSizeInBits();
310 O << ".param .b" << size << " _";
311 } else if (isa<PointerType>(retTy))
312 O << ".param .b" << getPointerTy().getSizeInBits() << " _";
// Aggregate/vector returns: one byte-array .param sized from the parts.
314 if ((retTy->getTypeID() == Type::StructTyID) ||
315 isa<VectorType>(retTy)) {
316 SmallVector<EVT, 16> vtparts;
317 ComputeValueVTs(*this, retTy, vtparts);
318 unsigned totalsz = 0;
319 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
321 EVT elemtype = vtparts[i];
322 if (vtparts[i].isVector()) {
323 elems = vtparts[i].getVectorNumElements();
324 elemtype = vtparts[i].getVectorElementType();
326 for (unsigned j = 0, je = elems; j != je; ++j) {
327 unsigned sz = elemtype.getSizeInBits();
328 if (elemtype.isInteger() && (sz < 8))
333 O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
335 assert(false && "Unknown return type");
// Non-ABI return path: one .reg per value part, widened to >= 32 bits.
339 SmallVector<EVT, 16> vtparts;
340 ComputeValueVTs(*this, retTy, vtparts);
342 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
344 EVT elemtype = vtparts[i];
345 if (vtparts[i].isVector()) {
346 elems = vtparts[i].getVectorNumElements();
347 elemtype = vtparts[i].getVectorElementType();
350 for (unsigned j = 0, je = elems; j != je; ++j) {
351 unsigned sz = elemtype.getSizeInBits();
352 if (elemtype.isInteger() && (sz < 32))
354 O << ".reg .b" << sz << " _";
// Argument declarations follow the same .param / .reg split.
368 MVT thePointerTy = getPointerTy();
370 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
371 const Type *Ty = Args[i].Ty;
377 if (Outs[i].Flags.isByVal() == false) {
379 if (isa<IntegerType>(Ty)) {
380 sz = cast<IntegerType>(Ty)->getBitWidth();
383 } else if (isa<PointerType>(Ty))
384 sz = thePointerTy.getSizeInBits();
386 sz = Ty->getPrimitiveSizeInBits();
388 O << ".param .b" << sz << " ";
390 O << ".reg .b" << sz << " ";
// byval aggregates: declared as an aligned byte array of the alloc size.
394 const PointerType *PTy = dyn_cast<PointerType>(Ty);
395 assert(PTy && "Param with byval attribute should be a pointer type");
396 Type *ETy = PTy->getElementType();
399 unsigned align = Outs[i].Flags.getByValAlign();
400 unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
401 O << ".param .align " << align << " .b8 ";
403 O << "[" << sz << "]";
406 SmallVector<EVT, 16> vtparts;
407 ComputeValueVTs(*this, ETy, vtparts);
408 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
410 EVT elemtype = vtparts[i];
411 if (vtparts[i].isVector()) {
412 elems = vtparts[i].getVectorNumElements();
413 elemtype = vtparts[i].getVectorElementType();
416 for (unsigned j = 0, je = elems; j != je; ++j) {
417 unsigned sz = elemtype.getSizeInBits();
418 if (elemtype.isInteger() && (sz < 32))
420 O << ".reg .b" << sz << " ";
// Lowers a call for NVPTX. Overall shape:
//   1. CALLSEQ_START, then per outgoing argument declare a .param/.reg
//      slot (DeclareParam/DeclareScalarParam) and copy the value into it
//      (StoreParam/MoveToParam); byval aggregates are loaded piecewise
//      and stored at their offsets.
//   2. Declare the return slot (DeclareRet/DeclareRetParam); for
//      indirect calls emit the ".callprototype" string (getPrototype)
//      via an INLINEASM node.
//   3. Emit the printable call sequence: PrintCall/PrintCallUni,
//      CallVoid (callee), CallArgBegin/CallArg/LastCallArg/CallArgEnd,
//      and Prototype for indirect calls.
//   4. Load results back via LoadParam nodes, CALLSEQ_END.
// NOTE(review): this copy is missing many interleaved lines (variable
// declarations such as VT/isReg/elems/theVal/srcAddr, closing braces,
// trailing operands of several arrays) — restore from upstream before
// building; comments below describe only what the visible lines show.
435 SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
436 SmallVectorImpl<SDValue> &InVals) const {
437 SelectionDAG &DAG = CLI.DAG;
438 DebugLoc &dl = CLI.DL;
439 SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
440 SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
441 SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
442 SDValue Chain = CLI.Chain;
443 SDValue Callee = CLI.Callee;
444 bool &isTailCall = CLI.IsTailCall;
445 ArgListTy &Args = CLI.Args;
446 Type *retTy = CLI.RetTy;
447 ImmutableCallSite *CS = CLI.CS;
449 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
// Keep the pre-call chain: byval element loads below hang off it so
// they are not serialized behind the param stores.
451 SDValue tempChain = Chain;
453 DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true));
454 SDValue InFlag = Chain.getValue(1);
456 assert((Outs.size() == Args.size()) &&
457 "Unexpected number of arguments to function call");
458 unsigned paramCount = 0;
459 // Declare the .params or .reg need to pass values
461 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
464 if (Outs[i].Flags.isByVal() == false) {
466 // for ABI, declare .param .b<size> .param<n>;
467 // for nonABI, declare .reg .b<size> .param<n>;
471 unsigned sz = VT.getSizeInBits();
472 if (VT.isInteger() && (sz < 32))
474 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
475 SDValue DeclareParamOps[] = { Chain,
476 DAG.getConstant(paramCount, MVT::i32),
477 DAG.getConstant(sz, MVT::i32),
478 DAG.getConstant(isReg, MVT::i32), InFlag };
479 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
481 InFlag = Chain.getValue(1);
482 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
483 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
484 DAG.getConstant(0, MVT::i32), OutVals[i],
// Pick the store flavor: plain StoreParam, MoveToParam for the .reg
// path, or the sign/zero-extending variants for sub-32-bit integers.
487 unsigned opcode = NVPTXISD::StoreParam;
489 opcode = NVPTXISD::MoveToParam;
491 if (Outs[i].Flags.isZExt())
492 opcode = NVPTXISD::StoreParamU32;
493 else if (Outs[i].Flags.isSExt())
494 opcode = NVPTXISD::StoreParamS32;
496 Chain = DAG.getNode(opcode, dl, CopyParamVTs, CopyParamOps, 5);
498 InFlag = Chain.getValue(1);
// byval argument: decompose the pointee into value parts.
503 SmallVector<EVT, 16> vtparts;
504 const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
505 assert(PTy && "Type of a byval parameter should be pointer");
506 ComputeValueVTs(*this, PTy->getElementType(), vtparts);
509 // declare .param .align 16 .b8 .param<n>[<size>];
510 unsigned sz = Outs[i].Flags.getByValSize();
511 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
512 // The ByValAlign in the Outs[i].Flags is alway set at this point, so we
514 // worry about natural alignment or not. See TargetLowering::LowerCallTo()
515 SDValue DeclareParamOps[] = {
516 Chain, DAG.getConstant(Outs[i].Flags.getByValAlign(), MVT::i32),
517 DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
520 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
522 InFlag = Chain.getValue(1);
// Load each element of the byval aggregate and store it into the param
// space at its running offset.
523 unsigned curOffset = 0;
524 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
526 EVT elemtype = vtparts[j];
527 if (vtparts[j].isVector()) {
528 elems = vtparts[j].getVectorNumElements();
529 elemtype = vtparts[j].getVectorElementType();
531 for (unsigned k = 0, ke = elems; k != ke; ++k) {
532 unsigned sz = elemtype.getSizeInBits();
533 if (elemtype.isInteger() && (sz < 8))
536 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
537 DAG.getConstant(curOffset, getPointerTy()));
539 DAG.getLoad(elemtype, dl, tempChain, srcAddr,
540 MachinePointerInfo(), false, false, false, 0);
541 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
542 SDValue CopyParamOps[] = { Chain,
543 DAG.getConstant(paramCount, MVT::i32),
544 DAG.getConstant(curOffset, MVT::i32),
546 Chain = DAG.getNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
548 InFlag = Chain.getValue(1);
555 // Non-abi, struct or vector
556 // Declare a bunch or .reg .b<size> .param<n>
557 unsigned curOffset = 0;
558 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
560 EVT elemtype = vtparts[j];
561 if (vtparts[j].isVector()) {
562 elems = vtparts[j].getVectorNumElements();
563 elemtype = vtparts[j].getVectorElementType();
565 for (unsigned k = 0, ke = elems; k != ke; ++k) {
566 unsigned sz = elemtype.getSizeInBits();
567 if (elemtype.isInteger() && (sz < 32))
569 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
570 SDValue DeclareParamOps[] = { Chain,
571 DAG.getConstant(paramCount, MVT::i32),
572 DAG.getConstant(sz, MVT::i32),
573 DAG.getConstant(1, MVT::i32), InFlag };
574 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
576 InFlag = Chain.getValue(1);
578 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[i],
579 DAG.getConstant(curOffset, getPointerTy()));
581 DAG.getLoad(elemtype, dl, tempChain, srcAddr, MachinePointerInfo(),
582 false, false, false, 0);
583 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
584 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
585 DAG.getConstant(0, MVT::i32), theVal,
587 Chain = DAG.getNode(NVPTXISD::MoveToParam, dl, CopyParamVTs,
589 InFlag = Chain.getValue(1);
// Func != 0 distinguishes direct calls (callee is a GlobalAddress) from
// indirect calls through a pointer.
595 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
596 unsigned retAlignment = 0;
599 unsigned retCount = 0;
600 if (Ins.size() > 0) {
601 SmallVector<EVT, 16> resvtparts;
602 ComputeValueVTs(*this, retTy, resvtparts);
604 // Declare one .param .align 16 .b8 func_retval0[<size>] for ABI or
605 // individual .reg .b<size> func_retval<0..> for non ABI
606 unsigned resultsz = 0;
607 for (unsigned i = 0, e = resvtparts.size(); i != e; ++i) {
609 EVT elemtype = resvtparts[i];
610 if (resvtparts[i].isVector()) {
611 elems = resvtparts[i].getVectorNumElements();
612 elemtype = resvtparts[i].getVectorElementType();
614 for (unsigned j = 0, je = elems; j != je; ++j) {
615 unsigned sz = elemtype.getSizeInBits();
616 if (isABI == false) {
617 if (elemtype.isInteger() && (sz < 32))
620 if (elemtype.isInteger() && (sz < 8))
623 if (isABI == false) {
624 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
625 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(2, MVT::i32),
626 DAG.getConstant(sz, MVT::i32),
627 DAG.getConstant(retCount, MVT::i32),
629 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
631 InFlag = Chain.getValue(1);
638 if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
639 retTy->isPointerTy()) {
640 // Scalar needs to be at least 32bit wide
643 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
644 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
645 DAG.getConstant(resultsz, MVT::i32),
646 DAG.getConstant(0, MVT::i32), InFlag };
647 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
649 InFlag = Chain.getValue(1);
// Aggregate return: alignment comes from the callee's align attribute
// when present, otherwise from the data layout's ABI alignment.
651 if (Func) { // direct call
652 if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
653 retAlignment = getDataLayout()->getABITypeAlignment(retTy);
654 } else { // indirect call
655 const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
656 if (!llvm::getAlign(*CallI, 0, retAlignment))
657 retAlignment = getDataLayout()->getABITypeAlignment(retTy);
659 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
660 SDValue DeclareRetOps[] = { Chain,
661 DAG.getConstant(retAlignment, MVT::i32),
662 DAG.getConstant(resultsz / 8, MVT::i32),
663 DAG.getConstant(0, MVT::i32), InFlag };
664 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
666 InFlag = Chain.getValue(1);
672 // This is indirect function call case : PTX requires a prototype of the
674 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
675 // to be emitted, and the label has to used as the last arg of call
677 // The prototype is embedded in a string and put as the operand for an
679 SDVTList InlineAsmVTs = DAG.getVTList(MVT::Other, MVT::Glue);
680 std::string proto_string = getPrototype(retTy, Args, Outs, retAlignment);
// Intern the prototype string so the pointer stays valid in the DAG.
681 const char *asmstr = nvTM->getManagedStrPool()
682 ->getManagedString(proto_string.c_str())->c_str();
683 SDValue InlineAsmOps[] = {
684 Chain, DAG.getTargetExternalSymbol(asmstr, getPointerTy()),
685 DAG.getMDNode(0), DAG.getTargetConstant(0, MVT::i32), InFlag
687 Chain = DAG.getNode(ISD::INLINEASM, dl, InlineAsmVTs, InlineAsmOps, 5);
688 InFlag = Chain.getValue(1);
690 // Op to just print "call"
691 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
692 SDValue PrintCallOps[] = {
694 DAG.getConstant(isABI ? ((Ins.size() == 0) ? 0 : 1) : retCount, MVT::i32),
697 Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
698 dl, PrintCallVTs, PrintCallOps, 3);
699 InFlag = Chain.getValue(1);
701 // Ops to print out the function name
702 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
703 SDValue CallVoidOps[] = { Chain, Callee, InFlag };
704 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
705 InFlag = Chain.getValue(1);
707 // Ops to print out the param list
708 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
709 SDValue CallArgBeginOps[] = { Chain, InFlag };
710 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
712 InFlag = Chain.getValue(1);
// One CallArg per declared param; the final one prints differently.
714 for (unsigned i = 0, e = paramCount; i != e; ++i) {
717 opcode = NVPTXISD::LastCallArg;
719 opcode = NVPTXISD::CallArg;
720 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
721 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
722 DAG.getConstant(i, MVT::i32), InFlag };
723 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
724 InFlag = Chain.getValue(1);
726 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
727 SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
730 DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
731 InFlag = Chain.getValue(1);
734 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
735 SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
737 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
738 InFlag = Chain.getValue(1);
741 // Generate loads from param memory/moves from registers for result
742 if (Ins.size() > 0) {
744 unsigned resoffset = 0;
745 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
746 unsigned sz = Ins[i].VT.getSizeInBits();
747 if (Ins[i].VT.isInteger() && (sz < 8))
749 EVT LoadRetVTs[] = { Ins[i].VT, MVT::Other, MVT::Glue };
750 SDValue LoadRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
751 DAG.getConstant(resoffset, MVT::i32), InFlag };
752 SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, LoadRetVTs,
753 LoadRetOps, array_lengthof(LoadRetOps));
754 Chain = retval.getValue(1);
755 InFlag = retval.getValue(2);
756 InVals.push_back(retval);
// Non-ABI result path: one LoadParam per scalar element, rebuilding
// vectors with BUILD_VECTOR.
760 SmallVector<EVT, 16> resvtparts;
761 ComputeValueVTs(*this, retTy, resvtparts);
763 assert(Ins.size() == resvtparts.size() &&
764 "Unexpected number of return values in non-ABI case");
765 unsigned paramNum = 0;
766 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
767 assert(EVT(Ins[i].VT) == resvtparts[i] &&
768 "Unexpected EVT type in non-ABI case");
769 unsigned numelems = 1;
770 EVT elemtype = Ins[i].VT;
771 if (Ins[i].VT.isVector()) {
772 numelems = Ins[i].VT.getVectorNumElements();
773 elemtype = Ins[i].VT.getVectorElementType();
775 std::vector<SDValue> tempRetVals;
776 for (unsigned j = 0; j < numelems; ++j) {
777 EVT MoveRetVTs[] = { elemtype, MVT::Other, MVT::Glue };
778 SDValue MoveRetOps[] = { Chain, DAG.getConstant(0, MVT::i32),
779 DAG.getConstant(paramNum, MVT::i32),
781 SDValue retval = DAG.getNode(NVPTXISD::LoadParam, dl, MoveRetVTs,
782 MoveRetOps, array_lengthof(MoveRetOps));
783 Chain = retval.getValue(1);
784 InFlag = retval.getValue(2);
785 tempRetVals.push_back(retval);
788 if (Ins[i].VT.isVector())
789 InVals.push_back(DAG.getNode(ISD::BUILD_VECTOR, dl, Ins[i].VT,
790 &tempRetVals[0], tempRetVals.size()));
792 InVals.push_back(tempRetVals[0]);
796 Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
797 DAG.getIntPtrConstant(uniqueCallSite + 1, true),
801 // set isTailCall to false for now, until we figure out how to express
802 // tail call optimization in PTX
807 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
808 // (see LegalizeDAG.cpp). This is slow and uses local memory.
809 // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
// Lowers CONCAT_VECTORS in registers: extract each element of every
// operand vector and rebuild the result with a single BUILD_VECTOR.
// NOTE(review): the "SDValue" return-type line, loop closing braces and
// the second BUILD_VECTOR argument line are missing from this copy.
811 NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
812 SDNode *Node = Op.getNode();
813 DebugLoc dl = Node->getDebugLoc();
814 SmallVector<SDValue, 8> Ops;
815 unsigned NumOperands = Node->getNumOperands();
816 for (unsigned i = 0; i < NumOperands; ++i) {
817 SDValue SubOp = Node->getOperand(i);
818 EVT VVT = SubOp.getNode()->getValueType(0);
819 EVT EltVT = VVT.getVectorElementType();
820 unsigned NumSubElem = VVT.getVectorNumElements();
821 for (unsigned j = 0; j < NumSubElem; ++j) {
822 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
823 DAG.getIntPtrConstant(j)));
826 return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
// Central dispatcher for operations marked Custom in the constructor.
// NOTE(review): the "SDValue" return-type line and several case labels
// (e.g. for the LowerSTORE/LowerLOAD returns) are missing from this
// copy — restore from upstream.
831 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
832 switch (Op.getOpcode()) {
833 case ISD::RETURNADDR:
837 case ISD::GlobalAddress:
838 return LowerGlobalAddress(Op, DAG);
839 case ISD::INTRINSIC_W_CHAIN:
841 case ISD::BUILD_VECTOR:
842 case ISD::EXTRACT_SUBVECTOR:
844 case ISD::CONCAT_VECTORS:
845 return LowerCONCAT_VECTORS(Op, DAG);
847 return LowerSTORE(Op, DAG);
849 return LowerLOAD(Op, DAG);
851 llvm_unreachable("Custom lowering not defined for operation");
// Custom LOAD lowering: only i1 loads need special treatment (PTX has
// no predicate-register load), delegated to LowerLOADi1.
// NOTE(review): the fallthrough for non-i1 types and the closing brace
// are missing from this copy.
855 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
856 if (Op.getValueType() == MVT::i1)
857 return LowerLOADi1(Op, DAG);
// i1 load lowering: load the value as i8 and truncate to i1, i.e.
//   v1 = load i8
866 // v = trunc v1 to i1
// Returns a MergeValues node (value, chain) as the legalizer expects.
867 SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
868 SDNode *Node = Op.getNode();
869 LoadSDNode *LD = cast<LoadSDNode>(Node);
870 DebugLoc dl = Node->getDebugLoc();
871 assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
872 assert(Node->getValueType(0) == MVT::i1 &&
873 "Custom lowering for i1 load only");
874 DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
875 LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
876 LD->isInvariant(), LD->getAlignment());
877 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
878 // The legalizer (the caller) is expecting two values from the legalized
879 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
880 // in LegalizeDAG.cpp which also uses MergeValues.
881 SDValue Ops[] = { result, LD->getChain() };
882 return DAG.getMergeValues(Ops, 2, dl);
// Custom STORE lowering dispatcher: i1 stores go to LowerSTOREi1,
// vector stores to LowerSTOREVector.
// NOTE(review): the fallthrough return for all other types and the
// closing brace are missing from this copy.
885 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
886 EVT ValVT = Op.getOperand(1).getValueType();
887 if (ValVT == MVT::i1)
888 return LowerSTOREi1(Op, DAG);
889 else if (ValVT.isVector())
890 return LowerSTOREVector(Op, DAG);
// Lowers a vector store into an NVPTX StoreV2/StoreV4 memory-intrinsic
// node: scalarizes the value with EXTRACT_VECTOR_ELT, any-extends
// sub-16-bit elements to i16, and preserves the real element type as
// the node's memory VT.
// NOTE(review): the "SDValue" return-type line, the switch cases that
// validate the vector type and pick the opcode, and the final return
// are missing from this copy — restore from upstream.
896 NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
897 SDNode *N = Op.getNode();
898 SDValue Val = N->getOperand(1);
899 DebugLoc DL = N->getDebugLoc();
900 EVT ValVT = Val.getValueType();
902 if (ValVT.isVector()) {
903 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
904 // legal. We can (and should) split that into 2 stores of <2 x double> here
905 // but I'm leaving that as a TODO for now.
906 if (!ValVT.isSimple())
908 switch (ValVT.getSimpleVT().SimpleTy) {
921 // This is a "native" vector type
926 EVT EltVT = ValVT.getVectorElementType();
927 unsigned NumElts = ValVT.getVectorNumElements();
929 // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
930 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
931 // stored type to i16 and propogate the "real" type as the memory type.
932 bool NeedExt = false;
933 if (EltVT.getSizeInBits() < 16)
940 Opcode = NVPTXISD::StoreV2;
943 Opcode = NVPTXISD::StoreV4;
948 SmallVector<SDValue, 8> Ops;
950 // First is the chain
951 Ops.push_back(N->getOperand(0));
953 // Then the split values
954 for (unsigned i = 0; i < NumElts; ++i) {
955 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
956 DAG.getIntPtrConstant(i));
958 // ANY_EXTEND is correct here since the store will only look at the
959 // lower-order bits anyway.
960 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
961 Ops.push_back(ExtVal);
964 // Then any remaining arguments
965 for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
966 Ops.push_back(N->getOperand(i));
969 MemSDNode *MemSD = cast<MemSDNode>(N);
971 SDValue NewSt = DAG.getMemIntrinsicNode(
972 Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
973 MemSD->getMemoryVT(), MemSD->getMemOperand());
975 //return DCI.CombineTo(N, NewSt, true);
// i1 store lowering: zero-extend the value to i8 and emit a normal i8
// store (PTX cannot store predicate registers directly).
// NOTE(review): the "return Result;" line and closing brace are missing
// from this copy.
986 SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
987 SDNode *Node = Op.getNode();
988 DebugLoc dl = Node->getDebugLoc();
989 StoreSDNode *ST = cast<StoreSDNode>(Node);
990 SDValue Tmp1 = ST->getChain();
991 SDValue Tmp2 = ST->getBasePtr();
992 SDValue Tmp3 = ST->getValue();
993 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
994 unsigned Alignment = ST->getAlignment();
995 bool isVolatile = ST->isVolatile();
996 bool isNonTemporal = ST->isNonTemporal();
997 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, Tmp3);
998 SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2, ST->getPointerInfo(),
999 isVolatile, isNonTemporal, Alignment);
// Builds an external-symbol SDValue named "<inname><idx>"; the string is
// interned in the target machine's managed string pool so it outlives
// the DAG.
// NOTE(review): the line writing idx into the stringstream suffix is
// missing from this copy.
1003 SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
1004 int idx, EVT v) const {
1005 std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
1006 std::stringstream suffix;
1008 *name += suffix.str();
1009 return DAG.getTargetExternalSymbol(name->c_str(), v);
// Convenience wrapper: symbol ".PARAM<idx>" for formal-argument access.
// NOTE(review): the "SDValue" return-type line is missing from this copy.
1013 NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1014 return getExtSymb(DAG, ".PARAM", idx, v);
// Convenience wrapper: symbol ".HLPPARAM<idx>".
// NOTE(review): getExtSymb takes an EVT as its fourth argument above;
// this call passes only three — an argument (or a default in the
// header) is missing from this copy.
1017 SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1018 return getExtSymb(DAG, ".HLPPARAM", idx);
1021 // Check to see if the kernel argument is image*_t or sampler_t
// Returns true if `arg` is a pointer to one of the named OpenCL special
// struct types (image2d/image3d/sampler); compared by struct name.
// NOTE(review): the early-return lines between the PTy cast and the STy
// cast, plus the final return, are missing from this copy.
1023 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1024 static const char *const specialTypes[] = { "struct._image2d_t",
1025 "struct._image3d_t",
1026 "struct._sampler_t" };
1028 const Type *Ty = arg->getType();
1029 const PointerType *PTy = dyn_cast<PointerType>(Ty);
1037 const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
// Literal structs have no name; treat them as non-special.
1038 const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1040 for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1041 if (TypeName == specialTypes[i])
// LowerFormalArguments - Lower the incoming (formal) arguments for the
// current function into SDValues pushed onto InVals, one entry per argument
// (or per vector element for vector-typed arguments). Handles four cases:
// image/sampler kernel params (replaced by an i32 position constant), unused
// params (UNDEF), non-byval params (loaded from / moved from the .PARAM
// symbol), and byval params (MoveParam, plus a local stack copy when
// neither ABI nor kernel lowering applies).
1047 SDValue NVPTXTargetLowering::LowerFormalArguments(
1048 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1049 const SmallVectorImpl<ISD::InputArg> &Ins, DebugLoc dl, SelectionDAG &DAG,
1050 SmallVectorImpl<SDValue> &InVals) const {
1051 MachineFunction &MF = DAG.getMachineFunction();
1052 const DataLayout *TD = getDataLayout();
1054 const Function *F = MF.getFunction();
1055 const AttributeSet &PAL = F->getAttributes();
1057 SDValue Root = DAG.getRoot();
1058 std::vector<SDValue> OutChains;
1060 bool isKernel = llvm::isKernelFunction(*F);
// ABI-style param lowering (StoreRetval/param loads) requires sm_20+.
1061 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
// Collect the IR-level arguments and their types so they can be walked
// side by side below.
1063 std::vector<Type *> argTypes;
1064 std::vector<const Argument *> theArgs;
1065 for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1067 theArgs.push_back(I);
1068 argTypes.push_back(I->getType());
1070 //assert(argTypes.size() == Ins.size() &&
1071 // "Ins types and function types did not match");
1074 for (unsigned i = 0, e = argTypes.size(); i != e; ++i, ++idx) {
1075 Type *Ty = argTypes[i];
1076 EVT ObjectVT = getValueType(Ty);
1077 //assert(ObjectVT == Ins[i].VT &&
1078 // "Ins type did not match function type");
1080 // If the kernel argument is image*_t or sampler_t, convert it to
1081 // a i32 constant holding the parameter position. This can later
1082 // be matched in the AsmPrinter to output the correct mangled name.
1083 if (isImageOrSamplerVal(
1085 (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
1087 assert(isKernel && "Only kernels can have image/sampler params");
1088 InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
// Arguments with no uses still need a placeholder value in InVals:
// one UNDEF per element for vectors, or a single UNDEF otherwise.
1092 if (theArgs[i]->use_empty()) {
1094 if (ObjectVT.isVector()) {
1095 EVT EltVT = ObjectVT.getVectorElementType();
1096 unsigned NumElts = ObjectVT.getVectorNumElements();
1097 for (unsigned vi = 0; vi < NumElts; ++vi) {
1098 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
1101 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
1106 // In the following cases, assign a node order of "idx+1"
1107 // to newly created nodes. The SDNodes for params have to
1108 // appear in the same order as their order of appearance
1109 // in the original function. "idx+1" holds that order.
1110 if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) {
// Non-byval vector param: load each element separately from the
// .PARAM symbol at its byte offset within the parameter.
1111 if (ObjectVT.isVector()) {
1112 unsigned NumElts = ObjectVT.getVectorNumElements();
1113 EVT EltVT = ObjectVT.getVectorElementType();
1114 unsigned Offset = 0;
1115 for (unsigned vi = 0; vi < NumElts; ++vi) {
1116 SDValue A = getParamSymbol(DAG, idx, getPointerTy());
1117 SDValue B = DAG.getIntPtrConstant(Offset);
1118 SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
1119 //getParamSymbol(DAG, idx, EltVT),
1120 //DAG.getConstant(Offset, getPointerTy()));
// Null pointer in the param address space only carries the
// address-space information for the MachinePointerInfo.
1122 Value *SrcValue = Constant::getNullValue(PointerType::get(
1123 EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
1124 SDValue Ld = DAG.getLoad(
1125 EltVT, dl, Root, Addr, MachinePointerInfo(SrcValue), false, false,
1127 TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
1128 Offset += EltVT.getStoreSizeInBits() / 8;
1129 InVals.push_back(Ld);
1135 if (isABI || isKernel) {
1136 // If ABI, load from the param symbol
1137 SDValue Arg = getParamSymbol(DAG, idx);
1138 // Conjure up a value that we can get the address space from.
1139 // FIXME: Using a constant here is a hack.
1140 Value *srcValue = Constant::getNullValue(
1141 PointerType::get(ObjectVT.getTypeForEVT(F->getContext()),
1142 llvm::ADDRESS_SPACE_PARAM));
1143 SDValue p = DAG.getLoad(
1144 ObjectVT, dl, Root, Arg, MachinePointerInfo(srcValue), false, false,
1146 TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
1148 DAG.AssignOrdering(p.getNode(), idx + 1);
1149 InVals.push_back(p);
1151 // If no ABI, just move the param symbol
1152 SDValue Arg = getParamSymbol(DAG, idx, ObjectVT);
1153 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1155 DAG.AssignOrdering(p.getNode(), idx + 1);
1156 InVals.push_back(p);
1161 // Param has ByVal attribute
1162 if (isABI || isKernel) {
1163 // Return MoveParam(param symbol).
1164 // Ideally, the param symbol can be returned directly,
1165 // but when SDNode builder decides to use it in a CopyToReg(),
1166 // machine instruction fails because TargetExternalSymbol
1167 // (not lowered) is target dependent, and CopyToReg assumes
1168 // the source is lowered.
1169 SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
1170 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
1172 DAG.AssignOrdering(p.getNode(), idx + 1);
1174 InVals.push_back(p);
// Convert the param-space pointer to a generic pointer via the
// nvvm_ptr_local_to_gen intrinsic before handing it to the caller.
1176 SDValue p2 = DAG.getNode(
1177 ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
1178 DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
1179 InVals.push_back(p2);
1182 // Have to move a set of param symbols to registers and
1183 // store them locally and return the local pointer in InVals
1184 const PointerType *elemPtrType = dyn_cast<PointerType>(argTypes[i]);
1185 assert(elemPtrType && "Byval parameter should be a pointer type");
1186 Type *elemType = elemPtrType->getElementType();
1187 // Compute the constituent parts
1188 SmallVector<EVT, 16> vtparts;
1189 SmallVector<uint64_t, 16> offsets;
1190 ComputeValueVTs(*this, elemType, vtparts, &offsets, 0);
// Total byte size of the byval aggregate, used to size the stack copy.
1191 unsigned totalsize = 0;
1192 for (unsigned j = 0, je = vtparts.size(); j != je; ++j)
1193 totalsize += vtparts[j].getStoreSizeInBits();
1194 SDValue localcopy = DAG.getFrameIndex(
1195 MF.getFrameInfo()->CreateStackObject(totalsize / 8, 16, false),
// Store each scalar constituent of the byval param into the local
// stack object, accumulating one store chain per piece.
1197 unsigned sizesofar = 0;
1198 std::vector<SDValue> theChains;
1199 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
1200 unsigned numElems = 1;
1201 if (vtparts[j].isVector())
1202 numElems = vtparts[j].getVectorNumElements();
1203 for (unsigned k = 0, ke = numElems; k != ke; ++k) {
1204 EVT tmpvt = vtparts[j];
1205 if (tmpvt.isVector())
1206 tmpvt = tmpvt.getVectorElementType();
1207 SDValue arg = DAG.getNode(NVPTXISD::MoveParam, dl, tmpvt,
1208 getParamSymbol(DAG, idx, tmpvt));
1210 DAG.getNode(ISD::ADD, dl, getPointerTy(), localcopy,
1211 DAG.getConstant(sizesofar, getPointerTy()));
1212 theChains.push_back(DAG.getStore(
1213 Chain, dl, arg, addr, MachinePointerInfo(), false, false, 0));
1214 sizesofar += tmpvt.getStoreSizeInBits() / 8;
// Merge all the store chains and return the local copy's address.
1219 Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &theChains[0],
1221 InVals.push_back(localcopy);
1225 // Clang will check explicit VarArg and issue error if any. However, Clang
1226 // will let code with
1227 // implicit var arg like f() pass.
1228 // We treat this case as if the arg list is empty.
1229 //if (F.isVarArg()) {
1230 // assert(0 && "VarArg not supported yet!");
// If any chains were accumulated, merge them into the DAG root.
1233 if (!OutChains.empty())
1234 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
// LowerReturn - Lower outgoing return values. Each scalar (or each element
// of a vector return value) is emitted as a StoreRetval (ABI, sm_20+) or
// MoveToRetval (pre-ABI) node chained in order, followed by RET_FLAG.
1240 SDValue NVPTXTargetLowering::LowerReturn(
1241 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1242 const SmallVectorImpl<ISD::OutputArg> &Outs,
1243 const SmallVectorImpl<SDValue> &OutVals, DebugLoc dl,
1244 SelectionDAG &DAG) const {
1246 bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
// Running byte offset into the return value area (ABI path only).
1248 unsigned sizesofar = 0;
1250 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
1251 SDValue theVal = OutVals[i];
1252 EVT theValType = theVal.getValueType();
1253 unsigned numElems = 1;
1254 if (theValType.isVector())
1255 numElems = theValType.getVectorNumElements();
1256 for (unsigned j = 0, je = numElems; j != je; ++j) {
1257 SDValue tmpval = theVal;
// Vector returns are scalarized: extract element j and emit it
// as an individual store/move.
1258 if (theValType.isVector())
1259 tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
1260 theValType.getVectorElementType(), tmpval,
1261 DAG.getIntPtrConstant(j));
// ABI: index by byte offset; pre-ABI: index by ordinal (idx).
1262 Chain = DAG.getNode(
1263 isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
1264 MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
1266 if (theValType.isVector())
1267 sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
1269 sizesofar += theValType.getStoreSizeInBits() / 8;
1274 return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
// LowerAsmOperandForConstraint - NVPTX only handles single-letter
// constraints here; multi-letter constraints are rejected early, everything
// else is delegated to the generic TargetLowering implementation.
1277 void NVPTXTargetLowering::LowerAsmOperandForConstraint(
1278 SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
1279 SelectionDAG &DAG) const {
1280 if (Constraint.length() > 1)
1283 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
1286 // NVPTX supports vectors of legal types of any length in Intrinsics because the
1287 // NVPTX specific type legalizer
1288 // will legalize them to the PTX supported length.
1289 bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
1290 if (isTypeLegal(VT))
// A vector type is acceptable whenever its element type is legal,
// regardless of the vector length.
1292 if (VT.isVector()) {
1293 MVT eVT = VT.getVectorElementType();
1294 if (isTypeLegal(eVT))
1300 // llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
1302 // because we need the information that is only available in the "Value" type
1304 // pointer. In particular, the address space information.
// getTgtMemIntrinsic - Describe the memory behavior of NVPTX intrinsics that
// touch memory, so that the DAG builder can attach MachineMemOperands.
// Returns true (via the switch cases) when Info has been filled in.
1305 bool NVPTXTargetLowering::getTgtMemIntrinsic(
1306 IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
1307 switch (Intrinsic) {
// Floating-point atomic add: reads and writes an f32 at operand 0.
1311 case Intrinsic::nvvm_atomic_load_add_f32:
1312 Info.opc = ISD::INTRINSIC_W_CHAIN;
1313 Info.memVT = MVT::f32;
1314 Info.ptrVal = I.getArgOperand(0);
1317 Info.readMem = true;
1318 Info.writeMem = true;
// 32-bit atomic increment/decrement: reads and writes an i32.
1322 case Intrinsic::nvvm_atomic_load_inc_32:
1323 case Intrinsic::nvvm_atomic_load_dec_32:
1324 Info.opc = ISD::INTRINSIC_W_CHAIN;
1325 Info.memVT = MVT::i32;
1326 Info.ptrVal = I.getArgOperand(0);
1329 Info.readMem = true;
1330 Info.writeMem = true;
// ldu (load-uniform from global): read-only; memVT depends on the
// i/f/p flavor of the intrinsic.
1334 case Intrinsic::nvvm_ldu_global_i:
1335 case Intrinsic::nvvm_ldu_global_f:
1336 case Intrinsic::nvvm_ldu_global_p:
1338 Info.opc = ISD::INTRINSIC_W_CHAIN;
1339 if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
1340 Info.memVT = MVT::i32;
1341 else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
1342 Info.memVT = getPointerTy();
1344 Info.memVT = MVT::f32;
1345 Info.ptrVal = I.getArgOperand(0);
1348 Info.readMem = true;
1349 Info.writeMem = false;
1357 /// isLegalAddressingMode - Return true if the addressing mode represented
1358 /// by AM is legal for this target, for a load/store of the specified type.
1359 /// Used to guide target specific optimizations, like loop strength reduction
1360 /// (LoopStrengthReduce.cpp) and memory optimization for address mode
1361 /// (CodeGenPrepare.cpp)
1362 bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
1365 // AddrMode - This represents an addressing mode of:
1366 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
1368 // The legal address modes are
// A global base may not be combined with any offset, register, or scale.
1375 if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
1381 case 0: // "r", "r+i" or "i" is allowed
1384 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
1386 // Otherwise we have r+i.
1389 // No scale > 1 is allowed
1395 //===----------------------------------------------------------------------===//
1396 // NVPTX Inline Assembly Support
1397 //===----------------------------------------------------------------------===//
1399 /// getConstraintType - Given a constraint letter, return the type of
1400 /// constraint it is for this target.
1401 NVPTXTargetLowering::ConstraintType
1402 NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
// Single-letter register-class constraints are classified here; anything
// else falls through to the generic TargetLowering handling.
1403 if (Constraint.size() == 1) {
1404 switch (Constraint[0]) {
1415 return C_RegisterClass;
1418 return TargetLowering::getConstraintType(Constraint);
// getRegForInlineAsmConstraint - Map each single-letter NVPTX inline-asm
// constraint to its register class (Int8/Int16/Int32/Int64/Float32/Float64);
// unrecognized constraints defer to the generic implementation.
1421 std::pair<unsigned, const TargetRegisterClass *>
1422 NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
1424 if (Constraint.size() == 1) {
1425 switch (Constraint[0]) {
1427 return std::make_pair(0U, &NVPTX::Int8RegsRegClass);
1429 return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
1431 return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
1434 return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
1436 return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
1438 return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
1441 return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
1444 /// getFunctionAlignment - Return the Log2 alignment of this function.
1445 unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
1449 /// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
// Replaces a vector LOAD node with a single NVPTXISD::LoadV2/LoadV4 memory
// intrinsic node producing one scalar result per element, then rebuilds the
// original vector with BUILD_VECTOR. Pushes the vector and the load chain
// onto Results.
1450 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
1451 SmallVectorImpl<SDValue> &Results) {
1452 EVT ResVT = N->getValueType(0);
1453 DebugLoc DL = N->getDebugLoc();
1455 assert(ResVT.isVector() && "Vector load must have vector type");
1457 // We only handle "native" vector sizes for now, e.g. <4 x double> is not
1458 // legal. We can (and should) split that into 2 loads of <2 x double> here
1459 // but I'm leaving that as a TODO for now.
1460 assert(ResVT.isSimple() && "Can only handle simple types");
1461 switch (ResVT.getSimpleVT().SimpleTy) {
1474 // This is a "native" vector type
1478 EVT EltVT = ResVT.getVectorElementType();
1479 unsigned NumElts = ResVT.getVectorNumElements();
1481 // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
1482 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
1483 // loaded type to i16 and propagate the "real" type as the memory type.
1484 bool NeedTrunc = false;
1485 if (EltVT.getSizeInBits() < 16) {
// Pick LoadV2 or LoadV4 based on the element count and build the
// matching result-type list (N scalars plus the chain).
1490 unsigned Opcode = 0;
1497 Opcode = NVPTXISD::LoadV2;
1498 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1501 Opcode = NVPTXISD::LoadV4;
1502 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1503 LdResVTs = DAG.getVTList(ListVTs, 5);
1508 SmallVector<SDValue, 8> OtherOps;
1510 // Copy regular operands
1511 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1512 OtherOps.push_back(N->getOperand(i));
1514 LoadSDNode *LD = cast<LoadSDNode>(N);
1516 // The select routine does not have access to the LoadSDNode instance, so
1517 // pass along the extension information
1518 OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));
1520 SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
1521 OtherOps.size(), LD->getMemoryVT(),
1522 LD->getMemOperand());
// Collect the scalar results, truncating widened i1/i8 elements back
// to the original element type where needed.
1524 SmallVector<SDValue, 4> ScalarRes;
1526 for (unsigned i = 0; i < NumElts; ++i) {
1527 SDValue Res = NewLD.getValue(i);
1529 Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1530 ScalarRes.push_back(Res);
// The chain is the result immediately after the NumElts scalar values.
1533 SDValue LoadChain = NewLD.getValue(NumElts);
1536 DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1538 Results.push_back(BuildVec);
1539 Results.push_back(LoadChain);
// ReplaceINTRINSIC_W_CHAIN - Custom result legalization for the
// nvvm_ldg_global_* / nvvm_ldu_global_* intrinsics. Vector results are
// replaced by LDGV2/LDGV4/LDUV2/LDUV4 multi-output nodes (then rebuilt with
// BUILD_VECTOR); scalar i8 results are re-emitted with an i16 result type
// while keeping i8 as the memory type for instruction selection.
1543 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
1543 SmallVectorImpl<SDValue> &Results) {
1544 SDValue Chain = N->getOperand(0);
1545 SDValue Intrin = N->getOperand(1);
1546 DebugLoc DL = N->getDebugLoc();
1548 // Get the intrinsic ID
1549 unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
1553 case Intrinsic::nvvm_ldg_global_i:
1554 case Intrinsic::nvvm_ldg_global_f:
1555 case Intrinsic::nvvm_ldg_global_p:
1556 case Intrinsic::nvvm_ldu_global_i:
1557 case Intrinsic::nvvm_ldu_global_f:
1558 case Intrinsic::nvvm_ldu_global_p: {
1559 EVT ResVT = N->getValueType(0);
1561 if (ResVT.isVector()) {
1564 unsigned NumElts = ResVT.getVectorNumElements();
1565 EVT EltVT = ResVT.getVectorElementType();
1567 // Since LDU/LDG are target nodes, we cannot rely on DAG type legalization.
1568 // Therefore, we must ensure the type is legal. For i1 and i8, we set the
1569 // loaded type to i16 and propagate the "real" type as the memory type.
1570 bool NeedTrunc = false;
1571 if (EltVT.getSizeInBits() < 16) {
// Choose the 2-element or 4-element LDG/LDU opcode to match both
// the vector width and the ldg-vs-ldu flavor of the intrinsic.
1576 unsigned Opcode = 0;
1586 case Intrinsic::nvvm_ldg_global_i:
1587 case Intrinsic::nvvm_ldg_global_f:
1588 case Intrinsic::nvvm_ldg_global_p:
1589 Opcode = NVPTXISD::LDGV2;
1591 case Intrinsic::nvvm_ldu_global_i:
1592 case Intrinsic::nvvm_ldu_global_f:
1593 case Intrinsic::nvvm_ldu_global_p:
1594 Opcode = NVPTXISD::LDUV2;
1597 LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
1603 case Intrinsic::nvvm_ldg_global_i:
1604 case Intrinsic::nvvm_ldg_global_f:
1605 case Intrinsic::nvvm_ldg_global_p:
1606 Opcode = NVPTXISD::LDGV4;
1608 case Intrinsic::nvvm_ldu_global_i:
1609 case Intrinsic::nvvm_ldu_global_f:
1610 case Intrinsic::nvvm_ldu_global_p:
1611 Opcode = NVPTXISD::LDUV4;
1614 EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
1615 LdResVTs = DAG.getVTList(ListVTs, 5);
1620 SmallVector<SDValue, 8> OtherOps;
1622 // Copy regular operands
1624 OtherOps.push_back(Chain); // Chain
1625 // Skip operand 1 (intrinsic ID)
1627 for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
1628 OtherOps.push_back(N->getOperand(i));
1630 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1632 SDValue NewLD = DAG.getMemIntrinsicNode(
1633 Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
1634 MemSD->getMemoryVT(), MemSD->getMemOperand());
// Collect the scalar results, truncating widened elements back to the
// original element type where needed.
1636 SmallVector<SDValue, 4> ScalarRes;
1638 for (unsigned i = 0; i < NumElts; ++i) {
1639 SDValue Res = NewLD.getValue(i);
1642 DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
1643 ScalarRes.push_back(Res);
1646 SDValue LoadChain = NewLD.getValue(NumElts);
1649 DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);
1651 Results.push_back(BuildVec);
1652 Results.push_back(LoadChain);
// Scalar path: only i8 needs custom handling here.
1655 assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
1656 "Custom handling of non-i8 ldu/ldg?");
1658 // Just copy all operands as-is
1659 SmallVector<SDValue, 4> Ops;
1660 for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
1661 Ops.push_back(N->getOperand(i));
1663 // Force output to i16
1664 SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);
1666 MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
1668 // We make sure the memory type is i8, which will be used during isel
1669 // to select the proper instruction.
1671 DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
1672 Ops.size(), MVT::i8, MemSD->getMemOperand());
1674 Results.push_back(NewLD.getValue(0));
1675 Results.push_back(NewLD.getValue(1));
1681 void NVPTXTargetLowering::ReplaceNodeResults(
1682 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
1683 switch (N->getOpcode()) {
1685 report_fatal_error("Unhandled custom legalization");
1687 ReplaceLoadVector(N, DAG, Results);
1689 case ISD::INTRINSIC_W_CHAIN:
1690 ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);