ArrayRefize some code. No functionality change.
[oota-llvm.git] / lib / Target / NVPTX / NVPTXAsmPrinter.cpp
index d3dfb35e2611f587376761114ccd280339b48621..0115e1f5d3a82337f59147d0094f6e75028a0a4d 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "NVPTXAsmPrinter.h"
+#include "MCTargetDesc/NVPTXMCAsmInfo.h"
 #include "NVPTX.h"
 #include "NVPTXInstrInfo.h"
-#include "NVPTXTargetMachine.h"
+#include "NVPTXNumRegisters.h"
 #include "NVPTXRegisterInfo.h"
+#include "NVPTXTargetMachine.h"
 #include "NVPTXUtilities.h"
-#include "MCTargetDesc/NVPTXMCAsmInfo.h"
-#include "NVPTXNumRegisters.h"
+#include "cl_common_defines.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/DebugInfo.h"
-#include "llvm/Function.h"
-#include "llvm/GlobalVariable.h"
-#include "llvm/Module.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Assembly/Writer.h"
 #include "llvm/CodeGen/Analysis.h"
-#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/DebugInfo.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
-#include "llvm/Target/Mangler.h"
-#include "llvm/Target/TargetLoweringObjectFile.h"
-#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormattedStream.h"
-#include "llvm/DerivedTypes.h"
-#include "llvm/Support/TimeValue.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Support/Path.h"
-#include "llvm/Assembly/Writer.h"
-#include "cl_common_defines.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TimeValue.h"
+#include "llvm/Target/Mangler.h"
+#include "llvm/Target/TargetLoweringObjectFile.h"
 #include <sstream>
 using namespace llvm;
 
@@ -68,7 +69,54 @@ static cl::opt<bool, true>InterleaveSrc("nvptx-emit-src",
                                         cl::location(llvm::InterleaveSrcInPtx));
 
 
+namespace {
+/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
+/// depends.
+void DiscoverDependentGlobals(Value *V,
+                              DenseSet<GlobalVariable*> &Globals) {
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+    Globals.insert(GV);
+  else {
+    if (User *U = dyn_cast<User>(V)) {
+      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
+        DiscoverDependentGlobals(U->getOperand(i), Globals);
+      }
+    }
+  }
+}
 
+/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
+/// instances to be emitted, but only after any dependents have been added
+/// first.
+void VisitGlobalVariableForEmission(GlobalVariable *GV,
+                                    SmallVectorImpl<GlobalVariable*> &Order,
+                                    DenseSet<GlobalVariable*> &Visited,
+                                    DenseSet<GlobalVariable*> &Visiting) {
+  // Have we already visited this one?
+  if (Visited.count(GV)) return;
+
+  // Do we have a circular dependency?
+  if (Visiting.count(GV))
+    report_fatal_error("Circular dependency found in global variable set");
+
+  // Start visiting this global
+  Visiting.insert(GV);
+
+  // Make sure we visit all dependents first
+  DenseSet<GlobalVariable*> Others;
+  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
+    DiscoverDependentGlobals(GV->getOperand(i), Others);
+  
+  for (DenseSet<GlobalVariable*>::iterator I = Others.begin(),
+       E = Others.end(); I != E; ++I)
+    VisitGlobalVariableForEmission(*I, Order, Visited, Visiting);
+
+  // Now we can visit ourself
+  Order.push_back(GV);
+  Visited.insert(GV);
+  Visiting.erase(GV);
+}
+}
 
 // @TODO: This is a copy from AsmPrinter.cpp.  The function is static, so we
 // cannot just link to the existing version.
@@ -117,20 +165,14 @@ const MCExpr *nvptx::LowerConstant(const Constant *CV, AsmPrinter &AP) {
   case Instruction::GetElementPtr: {
     const DataLayout &TD = *AP.TM.getDataLayout();
     // Generate a symbolic expression for the byte address
-    const Constant *PtrVal = CE->getOperand(0);
-    SmallVector<Value*, 8> IdxVec(CE->op_begin()+1, CE->op_end());
-    int64_t Offset = TD.getIndexedOffset(PtrVal->getType(), IdxVec);
+    APInt OffsetAI(TD.getPointerSizeInBits(), 0);
+    cast<GEPOperator>(CE)->accumulateConstantOffset(TD, OffsetAI);
 
     const MCExpr *Base = LowerConstant(CE->getOperand(0), AP);
-    if (Offset == 0)
+    if (!OffsetAI)
       return Base;
 
-    // Truncate/sext the offset to the pointer size.
-    if (TD.getPointerSizeInBits() != 64) {
-      int SExtAmount = 64-TD.getPointerSizeInBits();
-      Offset = (Offset << SExtAmount) >> SExtAmount;
-    }
-
+    int64_t Offset = OffsetAI.getSExtValue();
     return MCBinaryExpr::CreateAdd(Base, MCConstantExpr::Create(Offset, Ctx),
                                    Ctx);
   }
@@ -461,21 +503,7 @@ NVPTXAsmPrinter::getVirtualRegisterName(unsigned vr, bool isVec,
     O << getNVPTXRegClassStr(RC) << mapped_vr;
     return;
   }
-  // Vector virtual register
-  if (getNVPTXVectorSize(RC) == 4)
-    O << "{"
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_0, "
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_1, "
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_2, "
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_3"
-    << "}";
-  else if (getNVPTXVectorSize(RC) == 2)
-    O << "{"
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_0, "
-    << getNVPTXRegClassStr(RC) << mapped_vr << "_1"
-    << "}";
-  else
-    llvm_unreachable("Unsupported vector size");
+  report_fatal_error("Bad register!");
 }
 
 void
@@ -631,7 +659,7 @@ void NVPTXAsmPrinter::printLdStCode(const MachineInstr *MI, int opNum,
           O << ".global";
         break;
       default:
-        assert("wrong value");
+        llvm_unreachable("Wrong Address Space");
       }
     }
     else if (!strcmp(Modifier, "sign")) {
@@ -649,10 +677,10 @@ void NVPTXAsmPrinter::printLdStCode(const MachineInstr *MI, int opNum,
         O << ".v4";
     }
     else
-      assert("unknown modifier");
+      llvm_unreachable("Unknown Modifier");
   }
   else
-    assert("unknown modifier");
+    llvm_unreachable("Empty Modifier");
 }
 
 void NVPTXAsmPrinter::emitDeclaration (const Function *F, raw_ostream &O) {
@@ -893,10 +921,27 @@ bool NVPTXAsmPrinter::doInitialization (Module &M) {
 
   emitDeclarations(M, OS2);
 
-  // Print out module-level global variables here.
+  // As ptxas does not support forward references of globals, we need to first
+  // sort the list of module-level globals in def-use order. We visit each
+  // global variable in order, and ensure that we emit it *after* its dependent
+  // globals. We use a little extra memory maintaining both a set and a list to
+  // have fast searches while maintaining a strict ordering.
+  SmallVector<GlobalVariable*,8> Globals;
+  DenseSet<GlobalVariable*> GVVisited;
+  DenseSet<GlobalVariable*> GVVisiting;
+
+  // Visit each global variable, in order
   for (Module::global_iterator I = M.global_begin(), E = M.global_end();
-      I != E; ++I)
-    printModuleLevelGV(I, OS2);
+       I != E; ++I)
+    VisitGlobalVariableForEmission(I, Globals, GVVisited, GVVisiting);
+
+  assert(GVVisited.size() == M.getGlobalList().size() && 
+         "Missed a global variable");
+  assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
+
+  // Print out module-level global variables in proper order
+  for (unsigned i = 0, e = Globals.size(); i != e; ++i)
+    printModuleLevelGV(Globals[i], OS2);
 
   OS2 << '\n';
 
@@ -910,7 +955,8 @@ void NVPTXAsmPrinter::emitHeader (Module &M, raw_ostream &O) {
   O << "//\n";
   O << "\n";
 
-  O << ".version 3.0\n";
+  unsigned PTXVersion = nvptxSubtarget.getPTXVersion();
+  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
 
   O << ".target ";
   O << nvptxSubtarget.getTargetName();
@@ -1254,7 +1300,8 @@ void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
     O << "shared" ;
     break;
   default:
-    llvm_unreachable("unexpected address space");
+    report_fatal_error("Bad address space found while emitting PTX");
+    break;
   }
 }
 
@@ -1422,7 +1469,7 @@ void NVPTXAsmPrinter::printParamName(int paramIndex, raw_ostream &O) {
 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
                                             raw_ostream &O) {
   const DataLayout *TD = TM.getDataLayout();
-  const AttrListPtr &PAL = F->getAttributes();
+  const AttributeSet &PAL = F->getAttributes();
   const TargetLowering *TLI = TM.getTargetLowering();
   Function::const_arg_iterator I, E;
   unsigned paramIndex = 0;
@@ -1456,8 +1503,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
       continue;
     }
 
-    if (PAL.getParamAttributes(paramIndex+1).
-          hasAttribute(Attributes::ByVal) == false) {
+    if (PAL.hasAttribute(paramIndex+1, Attribute::ByVal) == false) {
       // Just a scalar
       const PointerType *PTy = dyn_cast<PointerType>(Ty);
       if (isKernelFunc) {
@@ -1525,6 +1571,9 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
       // <a> = PAL.getparamalignment
       // size = typeallocsize of element type
       unsigned align = PAL.getParamAlignment(paramIndex+1);
+      if (align == 0)
+        align = TD->getABITypeAlignment(ETy);
+
       unsigned sz = TD->getTypeAllocSize(ETy);
       O << "\t.param .align " << align
           << " .b8 ";
@@ -1961,29 +2010,9 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI)
   case NVPTX::StoreParamI64:  case NVPTX::StoreParamI8:
   case NVPTX::StoreParamS32I8:  case NVPTX::StoreParamU32I8:
   case NVPTX::StoreParamS32I16:  case NVPTX::StoreParamU32I16:
-  case NVPTX::StoreParamScalar2F32:  case NVPTX::StoreParamScalar2F64:
-  case NVPTX::StoreParamScalar2I16:  case NVPTX::StoreParamScalar2I32:
-  case NVPTX::StoreParamScalar2I64:  case NVPTX::StoreParamScalar2I8:
-  case NVPTX::StoreParamScalar4F32:  case NVPTX::StoreParamScalar4I16:
-  case NVPTX::StoreParamScalar4I32:  case NVPTX::StoreParamScalar4I8:
-  case NVPTX::StoreParamV2F32:  case NVPTX::StoreParamV2F64:
-  case NVPTX::StoreParamV2I16:  case NVPTX::StoreParamV2I32:
-  case NVPTX::StoreParamV2I64:  case NVPTX::StoreParamV2I8:
-  case NVPTX::StoreParamV4F32:  case NVPTX::StoreParamV4I16:
-  case NVPTX::StoreParamV4I32:  case NVPTX::StoreParamV4I8:
   case NVPTX::StoreRetvalF32:  case NVPTX::StoreRetvalF64:
   case NVPTX::StoreRetvalI16:  case NVPTX::StoreRetvalI32:
   case NVPTX::StoreRetvalI64:  case NVPTX::StoreRetvalI8:
-  case NVPTX::StoreRetvalScalar2F32:  case NVPTX::StoreRetvalScalar2F64:
-  case NVPTX::StoreRetvalScalar2I16:  case NVPTX::StoreRetvalScalar2I32:
-  case NVPTX::StoreRetvalScalar2I64:  case NVPTX::StoreRetvalScalar2I8:
-  case NVPTX::StoreRetvalScalar4F32:  case NVPTX::StoreRetvalScalar4I16:
-  case NVPTX::StoreRetvalScalar4I32:  case NVPTX::StoreRetvalScalar4I8:
-  case NVPTX::StoreRetvalV2F32:  case NVPTX::StoreRetvalV2F64:
-  case NVPTX::StoreRetvalV2I16:  case NVPTX::StoreRetvalV2I32:
-  case NVPTX::StoreRetvalV2I64:  case NVPTX::StoreRetvalV2I8:
-  case NVPTX::StoreRetvalV4F32:  case NVPTX::StoreRetvalV4I16:
-  case NVPTX::StoreRetvalV4I32:  case NVPTX::StoreRetvalV4I8:
   case NVPTX::LastCallArgF32:  case NVPTX::LastCallArgF64:
   case NVPTX::LastCallArgI16:  case NVPTX::LastCallArgI32:
   case NVPTX::LastCallArgI32imm:  case NVPTX::LastCallArgI64:
@@ -1994,16 +2023,6 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI)
   case NVPTX::LoadParamRegF32:  case NVPTX::LoadParamRegF64:
   case NVPTX::LoadParamRegI16:  case NVPTX::LoadParamRegI32:
   case NVPTX::LoadParamRegI64:  case NVPTX::LoadParamRegI8:
-  case NVPTX::LoadParamScalar2F32:  case NVPTX::LoadParamScalar2F64:
-  case NVPTX::LoadParamScalar2I16:  case NVPTX::LoadParamScalar2I32:
-  case NVPTX::LoadParamScalar2I64:  case NVPTX::LoadParamScalar2I8:
-  case NVPTX::LoadParamScalar4F32:  case NVPTX::LoadParamScalar4I16:
-  case NVPTX::LoadParamScalar4I32:  case NVPTX::LoadParamScalar4I8:
-  case NVPTX::LoadParamV2F32:  case NVPTX::LoadParamV2F64:
-  case NVPTX::LoadParamV2I16:  case NVPTX::LoadParamV2I32:
-  case NVPTX::LoadParamV2I64:  case NVPTX::LoadParamV2I8:
-  case NVPTX::LoadParamV4F32:  case NVPTX::LoadParamV4I16:
-  case NVPTX::LoadParamV4I32:  case NVPTX::LoadParamV4I8:
   case NVPTX::PrototypeInst:   case NVPTX::DBG_VALUE:
     return true;
   }