PTX: Add intrinsic support for ntid, ctaid, and nctaid registers
[oota-llvm.git] / lib / Target / PTX / PTXAsmPrinter.cpp
index cd27fb5d82efcbe1432dfb782467a6da135193b0..2c4c79b2f103be3d841a3f3b4df6f16e33a1c0ef 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/Target/Mangler.h"
 
 using namespace llvm;
 
-static cl::opt<std::string>
-OptPTXVersion("ptx-version", cl::desc("Set PTX version"),
-           cl::init("1.4"));
-
-static cl::opt<std::string>
-OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"),
-           cl::init("sm_10"));
-
 namespace {
 class PTXAsmPrinter : public AsmPrinter {
 public:
@@ -67,6 +60,8 @@ public:
   void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS);
   void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
                        const char *Modifier = 0);
+  void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
+                         const char *Modifier = 0);
 
   // autogen'd.
   void printInstruction(const MachineInstr *MI, raw_ostream &OS);
@@ -81,10 +76,14 @@ private:
 static const char PARAM_PREFIX[] = "__param_";
 
 static const char *getRegisterTypeName(unsigned RegNo) {
-#define TEST_REGCLS(cls, clsstr) \
+#define TEST_REGCLS(cls, clsstr)                \
   if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
-  TEST_REGCLS(RRegs32, s32);
   TEST_REGCLS(Preds, pred);
+  TEST_REGCLS(RRegu16, u16);
+  TEST_REGCLS(RRegu32, u32);
+  TEST_REGCLS(RRegu64, u64);
+  TEST_REGCLS(RRegf32, f32);
+  TEST_REGCLS(RRegf64, f64);
 #undef TEST_REGCLS
 
   llvm_unreachable("Not in any register class!");
@@ -103,11 +102,36 @@ static const char *getInstructionTypeName(const MachineInstr *MI) {
 }
 
 static const char *getStateSpaceName(unsigned addressSpace) {
-  if (addressSpace <= 255)
-    return "global";
-  // TODO Add more state spaces
+  switch (addressSpace) {
+  default: llvm_unreachable("Unknown state space");
+  case PTX::GLOBAL:    return "global";
+  case PTX::CONSTANT:  return "const";
+  case PTX::LOCAL:     return "local";
+  case PTX::PARAMETER: return "param";
+  case PTX::SHARED:    return "shared";
+  }
+  return NULL;
+}
 
-  llvm_unreachable("Unknown state space");
+static const char *getTypeName(const Type* type) {
+  while (true) {
+    switch (type->getTypeID()) {
+      default: llvm_unreachable("Unknown type");
+      case Type::FloatTyID: return ".f32";
+      case Type::DoubleTyID: return ".f64";
+      case Type::IntegerTyID:
+        switch (type->getPrimitiveSizeInBits()) {
+          default: llvm_unreachable("Unknown integer bit-width");
+          case 16: return ".u16";
+          case 32: return ".u32";
+          case 64: return ".u64";
+        }
+      case Type::ArrayTyID:
+      case Type::PointerTyID:
+        type = dyn_cast<const SequentialType>(type)->getElementType();
+        break;
+    }
+  }
   return NULL;
 }
 
@@ -142,8 +166,12 @@ bool PTXAsmPrinter::doFinalization(Module &M) {
 
 void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
 {
-  OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion));
-  OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget));
+  const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+
+  OutStreamer.EmitRawText(Twine("\t.version " + ST.getPTXVersionString()));
+  OutStreamer.EmitRawText(Twine("\t.target " + ST.getTargetString() +
+                                (ST.supportsDouble() ? ""
+                                                     : ", map_f64_to_f32")));
   OutStreamer.AddBlankLine();
 
   // declare global variables
@@ -214,6 +242,28 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     case MachineOperand::MO_Register:
       OS << getRegisterName(MO.getReg());
       break;
+    case MachineOperand::MO_FPImmediate:
+      APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt();
+      bool  isFloat = MO.getFPImm()->getType()->getTypeID() == Type::FloatTyID;
+      // Emit 0F for 32-bit floats and 0D for 64-bit doubles.
+      if (isFloat) {
+        OS << "0F";
+      }
+      else {
+        OS << "0D";
+      }
+      // Emit the encoded floating-point value.
+      if (constFP.getZExtValue() > 0) {
+        OS << constFP.toString(16, false);
+      }
+      else {
+        OS << "00000000";
+        // If We have a double-precision zero, pad to 8-bytes.
+        if (!isFloat) {
+          OS << "00000000";
+        }
+      }
+      break;
   }
 }
 
@@ -228,6 +278,11 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
   printOperand(MI, opNum+1, OS);
 }
 
+void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
+                                      raw_ostream &OS, const char *Modifier) {
+  OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
+}
+
 void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
   // Check to see if this is a special global used by LLVM, if so, emit it.
   if (EmitSpecialLLVMGlobal(gv))
@@ -256,8 +311,8 @@ void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
     decl += " ";
   }
 
-  // TODO: add types
-  decl += ".s32 ";
+  decl += getTypeName(gv->getType());
+  decl += " ";
 
   decl += gvsym->getName();
 
@@ -304,16 +359,25 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
   if (!MFI->argRegEmpty()) {
     decl += " (";
     if (isKernel) {
-      for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
-        if (i != 0)
+      unsigned cnt = 0;
+      //for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
+      for(PTXMachineFunctionInfo::reg_reverse_iterator
+          i = MFI->argRegReverseBegin(), e = MFI->argRegReverseEnd(), b = i;
+          i != e; ++i) {
+        reg = *i;
+        assert(reg != PTX::NoRegister && "Not a valid register!");
+        if (i != b)
           decl += ", ";
-        decl += ".param .s32 "; // TODO: add types
+        decl += ".param .";
+        decl += getRegisterTypeName(reg);
+        decl += " ";
         decl += PARAM_PREFIX;
-        decl += utostr(i + 1);
+        decl += utostr(++cnt);
       }
     } else {
-      for (PTXMachineFunctionInfo::reg_iterator
-           i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; i != e; ++i) {
+      for (PTXMachineFunctionInfo::reg_reverse_iterator
+           i = MFI->argRegReverseBegin(), e = MFI->argRegReverseEnd(), b = i;
+           i != e; ++i) {
         reg = *i;
         assert(reg != PTX::NoRegister && "Not a valid register!");
         if (i != b)