Remove incorrect comments. These are not disassmebly only patterns.
[oota-llvm.git] / lib / Target / PTX / PTXAsmPrinter.cpp
index c9b29158877dd0b7e55d32dd6c484169dc799a64..97bfed07958ce66d3547f3046113d1c844bf21c4 100644 (file)
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/Target/Mangler.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
-#include "llvm/Target/TargetRegistry.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/Path.h"
+#include "llvm/Support/TargetRegistry.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
@@ -70,6 +70,8 @@ public:
                           const char *Modifier = 0); 
   void printPredicateOperand(const MachineInstr *MI, raw_ostream &O);
 
+  void printCall(const MachineInstr *MI, raw_ostream &O);
+
   unsigned GetOrCreateSourceID(StringRef FileName,
                                StringRef DirName);
 
@@ -92,7 +94,6 @@ static const char *getRegisterTypeName(unsigned RegNo) {
 #define TEST_REGCLS(cls, clsstr)                \
   if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
   TEST_REGCLS(RegPred, pred);
-  TEST_REGCLS(RegI8,  b8);
   TEST_REGCLS(RegI16, b16);
   TEST_REGCLS(RegI32, b32);
   TEST_REGCLS(RegI64, b64);
@@ -116,7 +117,7 @@ static const char *getStateSpaceName(unsigned addressSpace) {
   return NULL;
 }
 
-static const char *getTypeName(const Type* type) {
+static const char *getTypeName(Type* type) {
   while (true) {
     switch (type->getTypeID()) {
       default: llvm_unreachable("Unknown type");
@@ -125,14 +126,13 @@ static const char *getTypeName(const Type* type) {
       case Type::IntegerTyID:
         switch (type->getPrimitiveSizeInBits()) {
           default: llvm_unreachable("Unknown integer bit-width");
-          case 8:  return ".u8";
           case 16: return ".u16";
           case 32: return ".u32";
           case 64: return ".u64";
         }
       case Type::ArrayTyID:
       case Type::PointerTyID:
-        type = dyn_cast<const SequentialType>(type)->getElementType();
+        type = dyn_cast<SequentialType>(type)->getElementType();
         break;
     }
   }
@@ -244,6 +244,19 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
       OutStreamer.EmitRawText(Twine(def));
     }
   }
+
+  unsigned Index = 1;
+  // Print parameter passing params
+  for (PTXMachineFunctionInfo::param_iterator
+       i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
+    std::string def = "\t.param .b";
+    def += utostr(*i);
+    def += " __ret_";
+    def += utostr(Index);
+    Index++;
+    def += ";";
+    OutStreamer.EmitRawText(Twine(def));
+  }
 }
 
 void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@@ -304,7 +317,11 @@ void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
   printPredicateOperand(MI, OS);
 
   // Write instruction to str
-  printInstruction(MI, OS);
+  if (MI->getOpcode() == PTX::CALL) {
+    printCall(MI, OS);
+  } else {
+    printInstruction(MI, OS);
+  }
   OS << ';';
   OS.flush();
 
@@ -408,8 +425,8 @@ void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
 
 
   if (PointerType::classof(gv->getType())) {
-    const PointerType* pointerTy = dyn_cast<const PointerType>(gv->getType());
-    const Type* elementTy = pointerTy->getElementType();
+    PointerType* pointerTy = dyn_cast<PointerType>(gv->getType());
+    Type* elementTy = pointerTy->getElementType();
 
     decl += ".b8 ";
     decl += gvsym->getName();
@@ -419,14 +436,14 @@ void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
     {
       assert(elementTy->isArrayTy() && "Only pointers to arrays are supported");
 
-      const ArrayType* arrayTy = dyn_cast<const ArrayType>(elementTy);
+      ArrayType* arrayTy = dyn_cast<ArrayType>(elementTy);
       elementTy = arrayTy->getElementType();
 
       unsigned numElements = arrayTy->getNumElements();
 
       while (elementTy->isArrayTy()) {
 
-        arrayTy = dyn_cast<const ArrayType>(elementTy);
+        arrayTy = dyn_cast<ArrayType>(elementTy);
         elementTy = arrayTy->getElementType();
 
         numElements *= arrayTy->getNumElements();
@@ -571,6 +588,28 @@ printPredicateOperand(const MachineInstr *MI, raw_ostream &O) {
   }
 }
 
+void PTXAsmPrinter::
+printCall(const MachineInstr *MI, raw_ostream &O) {
+
+  O << "\tcall.uni\t";
+
+  const GlobalValue *Address = MI->getOperand(2).getGlobal();
+  O << Address->getName() << ", (";
+
+  // (0,1) : predicate register/flag
+  // (2)   : callee
+  for (unsigned i = 3; i < MI->getNumOperands(); ++i) {
+    //const MachineOperand& MO = MI->getOperand(i);
+
+    printReturnOperand(MI, i, O);
+    if (i < MI->getNumOperands()-1) {
+      O << ", ";
+    }
+  }
+
+  O << ")";
+}
+
 unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
                                             StringRef DirName) {
   // If FE did not provide a file name, then assume stdin.