Temporarily Revert "Nuke the old JIT." as it's not quite ready to
[oota-llvm.git] / examples / Kaleidoscope / MCJIT / complete / toy.cpp
index 3beb0d8378938bf6f217f7a2dd40c7c4dc74a719..10e7ada1e88d75aaf359076eb57607b13060198f 100644 (file)
@@ -1,5 +1,6 @@
 #include "llvm/Analysis/Passes.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/JIT.h"
 #include "llvm/ExecutionEngine/MCJIT.h"
 #include "llvm/ExecutionEngine/ObjectCache.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
@@ -51,6 +52,10 @@ namespace {
                   cl::desc("Dump IR from modules to stderr on shutdown"),
                   cl::init(false));
 
+  cl::opt<bool> UseMCJIT(
+    "use-mcjit", cl::desc("Use the MCJIT execution engine"),
+    cl::init(true));
+
   cl::opt<bool> EnableLazyCompilation(
     "enable-lazy-compilation", cl::desc("Enable lazy compilation when using the MCJIT engine"),
     cl::init(true));
@@ -787,6 +792,96 @@ public:
   virtual void dump();
 };
 
+//===----------------------------------------------------------------------===//
+// Helper class for JIT execution engine
+//===----------------------------------------------------------------------===//
+
+class JITHelper : public BaseHelper {
+public:
+  JITHelper(LLVMContext &Context) {
+    // Make the module, which holds all the code.
+    if (!InputIR.empty()) {
+      TheModule = parseInputIR(InputIR, Context);
+    } else {
+      TheModule = new Module("my cool jit", Context);
+    }
+
+    // Create the JIT.  This takes ownership of the module.
+    std::string ErrStr;
+    TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
+    if (!TheExecutionEngine) {
+      fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
+      exit(1);
+    }
+
+    TheFPM = new FunctionPassManager(TheModule);
+
+    // Set up the optimizer pipeline.  Start with registering info about how the
+    // target lays out data structures.
+    TheFPM->add(new DataLayout(*TheExecutionEngine->getDataLayout()));
+    // Provide basic AliasAnalysis support for GVN.
+    TheFPM->add(createBasicAliasAnalysisPass());
+    // Promote allocas to registers.
+    TheFPM->add(createPromoteMemoryToRegisterPass());
+    // Do simple "peephole" optimizations and bit-twiddling optzns.
+    TheFPM->add(createInstructionCombiningPass());
+    // Reassociate expressions.
+    TheFPM->add(createReassociatePass());
+    // Eliminate Common SubExpressions.
+    TheFPM->add(createGVNPass());
+    // Simplify the control flow graph (deleting unreachable blocks, etc).
+    TheFPM->add(createCFGSimplificationPass());
+
+    TheFPM->doInitialization();
+  }
+
+  virtual ~JITHelper() {
+    if (TheFPM)
+      delete TheFPM;
+    if (TheExecutionEngine)
+      delete TheExecutionEngine;
+  }
+
+  virtual Function *getFunction(const std::string FnName) {
+    assert(TheModule);
+    return TheModule->getFunction(FnName);
+  }
+
+  virtual Module *getModuleForNewFunction() {
+    assert(TheModule);
+    return TheModule;
+  }
+
+  virtual void *getPointerToFunction(Function* F) {
+    assert(TheExecutionEngine);
+    return TheExecutionEngine->getPointerToFunction(F);
+  }
+
+  virtual void *getPointerToNamedFunction(const std::string &Name) {
+    return TheExecutionEngine->getPointerToNamedFunction(Name);
+  }
+
+  virtual void runFPM(Function &F) {
+    assert(TheFPM);
+    TheFPM->run(F);
+  }
+
+  virtual void closeCurrentModule() {
+    // This should never be called for JIT
+    assert(false);
+  }
+
+  virtual void dump() {
+    assert(TheModule);
+    TheModule->dump();
+  }
+
+private:
+  Module *TheModule;
+  ExecutionEngine *TheExecutionEngine;
+  FunctionPassManager *TheFPM;
+};
+
 //===----------------------------------------------------------------------===//
 // MCJIT helper class
 //===----------------------------------------------------------------------===//
@@ -939,6 +1034,7 @@ ExecutionEngine *MCJITHelper::compileModule(Module *M) {
   std::string ErrStr;
   ExecutionEngine *EE = EngineBuilder(M)
                             .setErrorStr(&ErrStr)
+                            .setUseMCJIT(true)
                             .setMCJITMemoryManager(new HelpingMemoryManager(this))
                             .create();
   if (!EE) {
@@ -1098,8 +1194,10 @@ Value *UnaryExprAST::Codegen() {
   Value *OperandV = Operand->Codegen();
   if (OperandV == 0) return 0;
   Function *F;
-  F = TheHelper->getFunction(
-      MakeLegalFunctionName(std::string("unary") + Opcode));
+  if (UseMCJIT)
+    F = TheHelper->getFunction(MakeLegalFunctionName(std::string("unary")+Opcode));
+  else
+    F = TheHelper->getFunction(std::string("unary")+Opcode);
   if (F == 0)
     return ErrorV("Unknown unary operator");
 
@@ -1148,7 +1246,10 @@ Value *BinaryExprAST::Codegen() {
   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
   // a call to it.
   Function *F;
-  F = TheHelper->getFunction(MakeLegalFunctionName(std::string("binary")+Op));
+  if (UseMCJIT)
+    F = TheHelper->getFunction(MakeLegalFunctionName(std::string("binary")+Op));
+  else
+    F = TheHelper->getFunction(std::string("binary")+Op);
   assert(F && "binary operator not found!");
 
   Value *Ops[] = { L, R };
@@ -1381,7 +1482,10 @@ Function *PrototypeAST::Codegen() {
                                        Doubles, false);
 
   std::string FnName;
-  FnName = MakeLegalFunctionName(Name);
+  if (UseMCJIT)
+    FnName = MakeLegalFunctionName(Name);
+  else
+    FnName = Name;
 
   Module* M = TheHelper->getModuleForNewFunction();
   Function *F = Function::Create(FT, Function::ExternalLinkage, FnName, M);
@@ -1456,6 +1560,10 @@ Function *FunctionAST::Codegen() {
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
+    // Optimize the function.
+    if (!UseMCJIT)
+      TheHelper->runFPM(*TheFunction);
+
     return TheFunction;
   }
 
@@ -1473,7 +1581,7 @@ Function *FunctionAST::Codegen() {
 
 static void HandleDefinition() {
   if (FunctionAST *F = ParseDefinition()) {
-    if (EnableLazyCompilation)
+    if (UseMCJIT && EnableLazyCompilation)
       TheHelper->closeCurrentModule();
     Function *LF = F->Codegen();
     if (LF && VerboseOutput) {
@@ -1563,8 +1671,10 @@ double printlf() {
 
 int main(int argc, char **argv) {
   InitializeNativeTarget();
-  InitializeNativeTargetAsmPrinter();
-  InitializeNativeTargetAsmParser();
+  if (UseMCJIT) {
+    InitializeNativeTargetAsmPrinter();
+    InitializeNativeTargetAsmParser();
+  }
   LLVMContext &Context = getGlobalContext();
 
   cl::ParseCommandLineOptions(argc, argv,
@@ -1580,7 +1690,10 @@ int main(int argc, char **argv) {
   BinopPrecedence['*'] = 40;  // highest.
 
   // Make the Helper, which holds all the code.
-  TheHelper = new MCJITHelper(Context);
+  if (UseMCJIT)
+    TheHelper = new MCJITHelper(Context);
+  else
+    TheHelper = new JITHelper(Context);
 
   // Prime the first token.
   if (!SuppressPrompts)