Make sure NVPTX doesn't emit symbol names that aren't valid in PTX.
authorEli Bendersky <eliben@google.com>
Mon, 10 Mar 2014 20:05:42 +0000 (20:05 +0000)
committerEli Bendersky <eliben@google.com>
Mon, 10 Mar 2014 20:05:42 +0000 (20:05 +0000)
NVPTX, like the other backends, relies on generic symbol name sanitizing done by
MCSymbol. However, the ptxas assembler is more stringent and disallows some
additional characters in symbol names.

See PR19099 for more details.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@203483 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXAsmPrinter.cpp
lib/Target/NVPTX/NVPTXAsmPrinter.h

index 25108bc39aa8126f1b2f973f9640437a4f033e4b..0cbdcc49aa9a6b1ae254cce29ec19f5192e5a9e0 100644 (file)
@@ -684,7 +684,7 @@ void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
   else
     O << ".func ";
   printReturnValStr(F, O);
-  O << *getSymbol(F) << "\n";
+  O << getSymbolName(F) << "\n";
   emitFunctionParamList(F, O);
   O << ";\n";
 }
@@ -1209,7 +1209,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     else
       O << getPTXFundamentalTypeStr(ETy, false);
     O << " ";
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
 
     // Ptx allows variable initilization only for constant and global state
     // spaces.
@@ -1245,15 +1245,15 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols) {
             if (nvptxSubtarget.is64Bit()) {
-              O << " .u64 " << *getSymbol(GVar) << "[";
+              O << " .u64 " << getSymbolName(GVar) << "[";
               O << ElementSize / 8;
             } else {
-              O << " .u32 " << *getSymbol(GVar) << "[";
+              O << " .u32 " << getSymbolName(GVar) << "[";
               O << ElementSize / 4;
             }
             O << "]";
           } else {
-            O << " .b8 " << *getSymbol(GVar) << "[";
+            O << " .b8 " << getSymbolName(GVar) << "[";
             O << ElementSize;
             O << "]";
           }
@@ -1261,7 +1261,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           aggBuffer.print();
           O << "}";
         } else {
-          O << " .b8 " << *getSymbol(GVar);
+          O << " .b8 " << getSymbolName(GVar);
           if (ElementSize) {
             O << "[";
             O << ElementSize;
@@ -1269,7 +1269,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           }
         }
       } else {
-        O << " .b8 " << *getSymbol(GVar);
+        O << " .b8 " << getSymbolName(GVar);
         if (ElementSize) {
           O << "[";
           O << ElementSize;
@@ -1376,7 +1376,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
     O << " .";
     O << getPTXFundamentalTypeStr(ETy);
     O << " ";
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
     return;
   }
 
@@ -1391,7 +1391,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   case Type::ArrayTyID:
   case Type::VectorTyID:
     ElementSize = TD->getTypeStoreSize(ETy);
-    O << " .b8 " << *getSymbol(GVar) << "[";
+    O << " .b8 " << getSymbolName(GVar) << "[";
     if (ElementSize) {
       O << itostr(ElementSize);
     }
@@ -1446,7 +1446,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
                                      int paramIndex, raw_ostream &O) {
   if ((nvptxSubtarget.getDrvInterface() == NVPTX::NVCL) ||
       (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA))
-    O << *getSymbol(I->getParent()) << "_param_" << paramIndex;
+    O << getSymbolName(I->getParent()) << "_param_" << paramIndex;
   else {
     std::string argName = I->getName();
     const char *p = argName.c_str();
@@ -1505,13 +1505,13 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       if (llvm::isImage(*I)) {
         std::string sname = I->getName();
         if (llvm::isImageWriteOnly(*I))
-          O << "\t.param .surfref " << *getSymbol(F) << "_param_"
+          O << "\t.param .surfref " << getSymbolName(F) << "_param_"
             << paramIndex;
         else // Default image is read_only
-          O << "\t.param .texref " << *getSymbol(F) << "_param_"
+          O << "\t.param .texref " << getSymbolName(F) << "_param_"
             << paramIndex;
       } else // Should be llvm::isSampler(*I)
-        O << "\t.param .samplerref " << *getSymbol(F) << "_param_"
+        O << "\t.param .samplerref " << getSymbolName(F) << "_param_"
           << paramIndex;
       continue;
     }
@@ -1758,13 +1758,13 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
     return;
   }
   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
-    O << *getSymbol(GVar);
+    O << getSymbolName(GVar);
     return;
   }
   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
     const Value *v = Cexpr->stripPointerCasts();
     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
-      O << *getSymbol(GVar);
+      O << getSymbolName(GVar);
       return;
     } else {
       O << *LowerConstant(CPV, *this);
@@ -2078,7 +2078,7 @@ void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     break;
 
   case MachineOperand::MO_GlobalAddress:
-    O << *getSymbol(MO.getGlobal());
+    O << getSymbolName(MO.getGlobal());
     break;
 
   case MachineOperand::MO_MachineBasicBlock:
@@ -2139,6 +2139,33 @@ LineReader *NVPTXAsmPrinter::getReader(std::string filename) {
   return reader;
 }
 
+std::string NVPTXAsmPrinter::getSymbolName(const GlobalValue *GV) const {
+  // Obtain the original symbol name.
+  MCSymbol *Sym = getSymbol(GV);
+  std::string OriginalName;
+  raw_string_ostream OriginalNameStream(OriginalName);
+  Sym->print(OriginalNameStream);
+  OriginalNameStream.flush();
+
+  // MCSymbol already does symbol-name sanitizing, so names it produces are
+  // valid for object files. The only two characters valida in that context
+  // and indigestible by the PTX assembler are '.' and '@'.
+  std::string CleanName;
+  raw_string_ostream CleanNameStream(CleanName);
+  for (unsigned I = 0, E = OriginalName.size(); I != E; ++I) {
+    char C = OriginalName[I];
+    if (C == '.') {
+      CleanNameStream << "_$_";
+    } else if (C == '@') {
+      CleanNameStream << "_%_";
+    } else {
+      CleanNameStream << C;
+    }
+  }
+
+  return CleanNameStream.str();
+}
+
 std::string LineReader::readLine(unsigned lineNum) {
   if (lineNum < theCurLine) {
     theCurLine = 0;
index 71624200d0ec1d2d75e561504e6a549f8ef7f764..abce85c39d74288ea0dd51f2500c7d5743b1b5a9 100644 (file)
@@ -276,6 +276,11 @@ private:
 
   LineReader *reader;
   LineReader *getReader(std::string);
+
+  // Get the symbol name of the given global symbol.
+  //
+  // Cleans up the name so it's a valid in PTX assembly.
+  std::string getSymbolName(const GlobalValue *GV) const;
 public:
   NVPTXAsmPrinter(TargetMachine &TM, MCStreamer &Streamer)
       : AsmPrinter(TM, Streamer),