Fix a bug introduced by the Getelementptr change
[oota-llvm.git] / lib / Target / CBackend / Writer.cpp
index c706d089ba244188520241defb3d6081d495c6f8..c54af08e8442130e99a1bd988a0fc9ea97336a3c 100644 (file)
 #include "llvm/iPHINode.h"
 #include "llvm/iOther.h"
 #include "llvm/iOperators.h"
+#include "llvm/Pass.h"
 #include "llvm/SymbolTable.h"
 #include "llvm/SlotCalculator.h"
+#include "llvm/Analysis/FindUsedTypes.h"
 #include "llvm/Support/InstVisitor.h"
 #include "llvm/Support/InstIterator.h"
 #include "Support/StringExtras.h"
@@ -28,21 +30,40 @@ using std::map;
 using std::ostream;
 
 namespace {
-  class CWriter : public InstVisitor<CWriter> {
-    ostreamOut; 
-    SlotCalculator &Table;
+  class CWriter : public Pass, public InstVisitor<CWriter> {
+    ostream &Out; 
+    SlotCalculator *Table;
     const Module *TheModule;
     map<const Type *, string> TypeNames;
     std::set<const Value*> MangledGlobals;
   public:
-    inline CWriter(ostream &o, SlotCalculator &Tab, const Module *M)
-      : Out(o), Table(Tab), TheModule(M) {
+    CWriter(ostream &o) : Out(o) {}
+
+    void getAnalysisUsage(AnalysisUsage &AU) const {
+      AU.setPreservesAll();
+      AU.addRequired<FindUsedTypes>();
+    }
+
+    virtual bool run(Module &M) {
+      // Initialize
+      Table = new SlotCalculator(&M, false);
+      TheModule = &M;
+
+      // Ensure that all structure types have names...
+      bool Changed = nameAllUsedStructureTypes(M);
+
+      // Run...
+      printModule(&M);
+
+      // Free memory...
+      delete Table;
+      TypeNames.clear();
+      MangledGlobals.clear();
+      return false;
     }
-    
-    inline void write(Module *M) { printModule(M); }
 
     ostream &printType(const Type *Ty, const string &VariableName = "",
-                       bool IgnoreName = false);
+                       bool IgnoreName = false, bool namedContext = true);
 
     void writeOperand(Value *Operand);
     void writeOperandInternal(Value *Operand);
@@ -50,12 +71,12 @@ namespace {
     string getValueName(const Value *V);
 
   private :
+    bool nameAllUsedStructureTypes(Module &M);
     void printModule(Module *M);
     void printSymbolTable(const SymbolTable &ST);
     void printGlobal(const GlobalVariable *GV);
-    void printFunctionSignature(const Function *F);
-    void printFunctionDecl(const Function *F); // Print just the forward decl
-    
+    void printFunctionSignature(const Function *F, bool Prototype);
+
     void printFunction(Function *);
 
     void printConstant(Constant *CPV);
@@ -138,16 +159,21 @@ string CWriter::getValueName(const Value *V) {
            makeNameProper(V->getName());      
   }
 
-  int Slot = Table.getValSlot(V);
+  int Slot = Table->getValSlot(V);
   assert(Slot >= 0 && "Invalid value!");
   return "ltmp_" + itostr(Slot) + "_" + utostr(V->getType()->getUniqueID());
 }
 
+// A pointer type should not use parens around *'s alone, e.g., (**)
+inline bool ptrTypeNameNeedsParens(const string &NameSoFar) {
+  return (NameSoFar.find_last_not_of('*') != std::string::npos);
+}
+
 // Pass the Type* and the variable name and this prints out the variable
 // declaration.
 //
 ostream &CWriter::printType(const Type *Ty, const string &NameSoFar,
-                            bool IgnoreName = false) {
+                            bool IgnoreName, bool namedContext) {
   if (Ty->isPrimitiveType())
     switch (Ty->getPrimitiveID()) {
     case Type::VoidTyID:   return Out << "void "               << NameSoFar;
@@ -211,7 +237,15 @@ ostream &CWriter::printType(const Type *Ty, const string &NameSoFar,
 
   case Type::PointerTyID: {
     const PointerType *PTy = cast<PointerType>(Ty);
-    return printType(PTy->getElementType(), "(*" + NameSoFar + ")");
+    std::string ptrName = "*" + NameSoFar;
+
+    // Do not need parens around "* NameSoFar" if NameSoFar consists only
+    // of zero or more '*' chars *and* this is not an unnamed pointer type
+    // such as the result type in a cast statement.  Otherwise, enclose in ( ).
+    if (ptrTypeNameNeedsParens(NameSoFar) || !namedContext)
+      ptrName = "(" + ptrName + ")";    // 
+
+    return printType(PTy->getElementType(), ptrName);
   }
 
   case Type::ArrayTyID: {
@@ -394,7 +428,7 @@ void CWriter::writeOperandInternal(Value *Operand) {
   } else if (Constant *CPV = dyn_cast<Constant>(Operand)) {
     printConstant(CPV); 
   } else {
-    int Slot = Table.getValSlot(Operand);
+    int Slot = Table->getValSlot(Operand);
     assert(Slot >= 0 && "Malformed LLVM!");
     Out << "ltmp_" << Slot << "_" << Operand->getType()->getUniqueID();
   }
@@ -410,6 +444,36 @@ void CWriter::writeOperand(Value *Operand) {
     Out << ")";
 }
 
+// nameAllUsedStructureTypes - If there are structure types in the module that
+// are used but do not have names assigned to them in the symbol table yet then
+// we assign them names now.
+//
+bool CWriter::nameAllUsedStructureTypes(Module &M) {
+  // Get a set of types that are used by the program...
+  std::set<const Type *> UT = getAnalysis<FindUsedTypes>().getTypes();
+
+  // Loop over the module symbol table, removing types from UT that are already
+  // named.
+  //
+  SymbolTable *MST = M.getSymbolTableSure();
+  if (MST->find(Type::TypeTy) != MST->end())
+    for (SymbolTable::type_iterator I = MST->type_begin(Type::TypeTy),
+           E = MST->type_end(Type::TypeTy); I != E; ++I)
+      UT.erase(cast<Type>(I->second));
+
+  // UT now contains types that are not named.  Loop over it, naming structure
+  // types.
+  //
+  bool Changed = false;
+  for (std::set<const Type *>::const_iterator I = UT.begin(), E = UT.end();
+       I != E; ++I)
+    if (const StructType *ST = dyn_cast<StructType>(*I)) {
+      ((Value*)ST)->setName("unnamed", MST);
+      Changed = true;
+    }
+  return Changed;
+}
+
 void CWriter::printModule(Module *M) {
   // Calculate which global values have names that will collide when we throw
   // away type information.
@@ -430,7 +494,6 @@ void CWriter::printModule(Module *M) {
           FoundNames.insert(I->getName());   // Otherwise, keep track of name
   }
 
-
   // printing stdlib inclusion
   // Out << "#include <stdlib.h>\n";
 
@@ -439,9 +502,10 @@ void CWriter::printModule(Module *M) {
       << "#include <malloc.h>\n"
       << "#include <alloca.h>\n\n"
 
-    // Provide a definition for null if one does not already exist.
+    // Provide a definition for null if one does not already exist,
+    // and for `bool' if not compiling with a C++ compiler.
       << "#ifndef NULL\n#define NULL 0\n#endif\n\n"
-      << "typedef unsigned char bool;\n"
+      << "#ifndef __cplusplus\ntypedef unsigned char bool;\n#endif\n"
 
       << "\n\n/* Global Declarations */\n";
 
@@ -466,8 +530,10 @@ void CWriter::printModule(Module *M) {
   // Function declarations
   if (!M->empty()) {
     Out << "\n/* Function Declarations */\n";
-    for (Module::iterator I = M->begin(), E = M->end(); I != E; ++I)
-      printFunctionDecl(I);
+    for (Module::iterator I = M->begin(), E = M->end(); I != E; ++I) {
+      printFunctionSignature(I, true);
+      Out << ";\n";
+    }
   }
 
   // Output the global variable contents...
@@ -533,14 +599,7 @@ void CWriter::printSymbolTable(const SymbolTable &ST) {
 }
 
 
-// printFunctionDecl - Print function declaration
-//
-void CWriter::printFunctionDecl(const Function *F) {
-  printFunctionSignature(F);
-  Out << ";\n";
-}
-
-void CWriter::printFunctionSignature(const Function *F) {
+void CWriter::printFunctionSignature(const Function *F, bool Prototype) {
   if (F->hasInternalLinkage()) Out << "static ";
   
   // Loop over the arguments, printing them...
@@ -552,12 +611,20 @@ void CWriter::printFunctionSignature(const Function *F) {
     
   if (!F->isExternal()) {
     if (!F->aempty()) {
-      printType(F->afront().getType(), getValueName(F->abegin()));
+      string ArgName;
+      if (F->abegin()->hasName() || !Prototype)
+        ArgName = getValueName(F->abegin());
+
+      printType(F->afront().getType(), ArgName);
 
       for (Function::const_aiterator I = ++F->abegin(), E = F->aend();
            I != E; ++I) {
         Out << ", ";
-        printType(I->getType(), getValueName(I));
+        if (I->hasName() || !Prototype)
+          ArgName = getValueName(I);
+        else 
+          ArgName = "";
+        printType(I->getType(), ArgName);
       }
     }
   } else {
@@ -582,9 +649,9 @@ void CWriter::printFunctionSignature(const Function *F) {
 void CWriter::printFunction(Function *F) {
   if (F->isExternal()) return;
 
-  Table.incorporateFunction(F);
+  Table->incorporateFunction(F);
 
-  printFunctionSignature(F);
+  printFunctionSignature(F, false);
   Out << " {\n";
 
   // print local variable information for the function
@@ -632,7 +699,7 @@ void CWriter::printFunction(Function *F) {
   }
   
   Out << "}\n\n";
-  Table.purgeFunction();
+  Table->purgeFunction();
 }
 
 // Specific Instruction type classes... note that all of the casts are
@@ -752,7 +819,7 @@ void CWriter::visitBinaryOperator(Instruction &I) {
 
 void CWriter::visitCastInst(CastInst &I) {
   Out << "(";
-  printType(I.getType());
+  printType(I.getType(), string(""),/*ignoreName*/false, /*namedContext*/false);
   Out << ")";
   writeOperand(I.getOperand(0));
 }
@@ -762,7 +829,8 @@ void CWriter::visitCallInst(CallInst &I) {
   const FunctionType *FTy   = cast<FunctionType>(PTy->getElementType());
   const Type         *RetTy = FTy->getReturnType();
   
-  Out << getValueName(I.getOperand(0)) << "(";
+  writeOperand(I.getOperand(0));
+  Out << "(";
 
   if (I.getNumOperands() > 1) {
     writeOperand(I.getOperand(1));
@@ -842,29 +910,33 @@ void CWriter::printIndexingExpression(Value *Ptr, User::op_iterator I,
       Out << (HasImplicitAddress ? "." : "->");
       Out << "field" << cast<ConstantUInt>(*(I+1))->getValue();
       I += 2;
-    } else {  // Performing array indexing. Just skip the 0
+    } else {  // First array index of 0: Just skip it
       ++I;
     }
-  } else if (HasImplicitAddress) {
-    
   }
-    
+
   for (; I != E; ++I)
-    if ((*I)->getType() == Type::UIntTy) {
-      Out << "[";
+    if ((*I)->getType() == Type::LongTy) {
+      Out << "[((int) (";                 // sign-extend from 32 (to 64) bits
       writeOperand(*I);
-      Out << "]";
+      Out << " * sizeof(";
+      printType(cast<PointerType>(Ptr->getType())->getElementType());
+      Out << "))) / sizeof(";
+      printType(cast<PointerType>(Ptr->getType())->getElementType());
+      Out << ")]";
     } else {
       Out << ".field" << cast<ConstantUInt>(*I)->getValue();
     }
 }
 
 void CWriter::visitLoadInst(LoadInst &I) {
-  printIndexingExpression(I.getPointerOperand(), I.idx_begin(), I.idx_end());
+  Out << "*";
+  writeOperand(I.getOperand(0));
 }
 
 void CWriter::visitStoreInst(StoreInst &I) {
-  printIndexingExpression(I.getPointerOperand(), I.idx_begin(), I.idx_end());
+  Out << "*";
+  writeOperand(I.getPointerOperand());
   Out << " = ";
   writeOperand(I.getOperand(0));
 }
@@ -878,10 +950,4 @@ void CWriter::visitGetElementPtrInst(GetElementPtrInst &I) {
 //                       External Interface declaration
 //===----------------------------------------------------------------------===//
 
-void WriteToC(const Module *M, ostream &Out) {
-  assert(M && "You can't write a null module!!");
-  SlotCalculator SlotTable(M, false);
-  CWriter W(Out, SlotTable, M);
-  W.write((Module*)M);
-  Out.flush();
-}
+Pass *createWriteToCPass(std::ostream &o) { return new CWriter(o); }