Revert "Remove access to the DataLayout in the TargetMachine"
[oota-llvm.git] / examples / Kaleidoscope / Orc / fully_lazy / toy.cpp
index 9210dd1f3bb85e76c4960d20cd7346fd1c62355e..c9b2c6af56588b625f3c2e68fb1579fb71b476e8 100644 (file)
@@ -1,6 +1,7 @@
 #include "llvm/Analysis/Passes.h"
 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
 #include "llvm/ExecutionEngine/Orc/LazyEmittingLayer.h"
 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
 #include "llvm/ExecutionEngine/Orc/OrcTargetSupport.h"
@@ -20,7 +21,9 @@
 #include <sstream>
 #include <string>
 #include <vector>
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -680,14 +683,18 @@ std::string MakeLegalFunctionName(std::string Name)
 
 class SessionContext {
 public:
-  SessionContext(LLVMContext &C) : Context(C) {}
+  SessionContext(LLVMContext &C)
+    : Context(C), TM(EngineBuilder().selectTarget()) {}
   LLVMContext& getLLVMContext() const { return Context; }
+  TargetMachine& getTarget() { return *TM; }
   void addPrototypeAST(std::unique_ptr<PrototypeAST> P);
   PrototypeAST* getPrototypeAST(const std::string &Name);
-  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs; 
 private:
   typedef std::map<std::string, std::unique_ptr<PrototypeAST>> PrototypeMap;
+  
   LLVMContext &Context;
+  std::unique_ptr<TargetMachine> TM;
+      
   PrototypeMap Prototypes;
 };
 
@@ -709,7 +716,9 @@ public:
     : Session(S),
       M(new Module(GenerateUniqueName("jit_module_"),
                    Session.getLLVMContext())),
-      Builder(Session.getLLVMContext()) {}
+      Builder(Session.getLLVMContext()) {
+    M->setDataLayout(*Session.getTarget().getDataLayout());
+  }
 
   SessionContext& getSession() { return Session; }
   Module& getM() const { return *M; }
@@ -1138,6 +1147,12 @@ static std::unique_ptr<llvm::Module> IRGen(SessionContext &S,
   return C.takeM();
 }
 
+template <typename T>
+static std::vector<T> singletonSet(T t) {
+  std::vector<T> Vec;
+  Vec.push_back(std::move(t));
+  return Vec;
+}
 
 static void EarthShatteringKaboom() {
   fprintf(stderr, "Earth shattering kaboom.");
@@ -1149,111 +1164,121 @@ public:
   typedef ObjectLinkingLayer<> ObjLayerT;
   typedef IRCompileLayer<ObjLayerT> CompileLayerT;
   typedef LazyEmittingLayer<CompileLayerT> LazyEmitLayerT;
-
   typedef LazyEmitLayerT::ModuleSetHandleT ModuleHandleT;
 
-  std::string Mangle(const std::string &Name) {
+  KaleidoscopeJIT(SessionContext &Session)
+    : Session(Session),
+      CompileLayer(ObjectLayer, SimpleCompiler(Session.getTarget())),
+      LazyEmitLayer(CompileLayer),
+      CompileCallbacks(LazyEmitLayer, CCMgrMemMgr, Session.getLLVMContext(),
+                       reinterpret_cast<uintptr_t>(EarthShatteringKaboom),
+                       64) {}
+
+  std::string mangle(const std::string &Name) {
     std::string MangledName;
     {
       raw_string_ostream MangledNameStream(MangledName);
-      Mang.getNameWithPrefix(MangledNameStream, Name);
+      Mangler::getNameWithPrefix(MangledNameStream, Name,
+                                 *Session.getTarget().getDataLayout());
     }
     return MangledName;
   }
 
-  KaleidoscopeJIT(SessionContext &Session)
-    : TM(EngineBuilder().selectTarget()),
-      Mang(TM->getDataLayout()), Session(Session),
-      ObjectLayer(
-        [](){ return llvm::make_unique<SectionMemoryManager>(); }),
-      CompileLayer(ObjectLayer, SimpleCompiler(*TM)),
-      LazyEmitLayer(CompileLayer),
-      CompileCallbacks(LazyEmitLayer, Session.getLLVMContext(),
-                       reinterpret_cast<uintptr_t>(EarthShatteringKaboom),
-                       64) {}
+  void addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
+    std::cerr << "Adding AST: " << FnAST->Proto->Name << "\n";
+    FunctionDefs[mangle(FnAST->Proto->Name)] = std::move(FnAST);
+  }
 
   ModuleHandleT addModule(std::unique_ptr<Module> M) {
-    if (!M->getDataLayout())
-      M->setDataLayout(TM->getDataLayout());
-
-    // The LazyEmitLayer takes lists of modules, rather than single modules, so
-    // we'll just build a single-element list.
-    std::vector<std::unique_ptr<Module>> S;
-    S.push_back(std::move(M));
-
     // We need a memory manager to allocate memory and resolve symbols for this
-    // new module. Create one that resolves symbols by looking back into the JIT.
-    auto MM = createLookasideRTDyldMM<SectionMemoryManager>(
-                [&](const std::string &Name) -> uint64_t {
-                  // First try to find 'Name' within the JIT.
-                  if (auto Symbol = findMangledSymbol(Name))
-                    return Symbol.getAddress();
-
-                  // If we don't find 'Name' in the JIT, see if we have some AST
-                  // for it.
-                  auto DefI = Session.FunctionDefs.find(Name);
-                  if (DefI == Session.FunctionDefs.end())
-                    return 0;
-
-                  // We have AST for 'Name'. IRGen it, add it to the JIT, and
-                  // return the address for it.
-                  // FIXME: What happens if IRGen fails?
-                  addModule(IRGen(Session, *DefI->second));
-
-                  // Remove the function definition's AST now that we've
-                  // finished with it.
-                  Session.FunctionDefs.erase(DefI);
-
-                  return findMangledSymbol(Name).getAddress();
-                },
-                [](const std::string &S) { return 0; } );
-
-    return LazyEmitLayer.addModuleSet(std::move(S), std::move(MM));
+    // new module. Create one that resolves symbols by looking back into the
+    // JIT.
+    auto Resolver = createLambdaResolver(
+                      [&](const std::string &Name) {
+                        // First try to find 'Name' within the JIT.
+                        if (auto Symbol = findSymbol(Name))
+                          return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                                         Symbol.getFlags());
+
+                        // If we don't already have a definition of 'Name' then search
+                        // the ASTs.
+                        return searchFunctionASTs(Name);
+                      },
+                      [](const std::string &S) { return nullptr; } );
+
+    return LazyEmitLayer.addModuleSet(singletonSet(std::move(M)),
+                                      make_unique<SectionMemoryManager>(),
+                                      std::move(Resolver));
   }
 
   void removeModule(ModuleHandleT H) { LazyEmitLayer.removeModuleSet(H); }
 
-  JITSymbol findMangledSymbol(const std::string &Name) {
-    return LazyEmitLayer.findSymbol(Name, true);
+  JITSymbol findSymbol(const std::string &Name) {
+    return LazyEmitLayer.findSymbol(Name, false);
   }
 
-  JITSymbol findSymbol(const std::string &Name) {
-    return findMangledSymbol(Mangle(Name));
+  JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name) {
+    return LazyEmitLayer.findSymbolIn(H, Name, false);
   }
 
-  JITSymbol findMangledSymbolIn(LazyEmitLayerT::ModuleSetHandleT H,
-                                const std::string &Name) {
-    return LazyEmitLayer.findSymbolIn(H, Name, true); 
+  JITSymbol findUnmangledSymbol(const std::string &Name) {
+    return findSymbol(mangle(Name));
   }
 
-  JITSymbol findSymbolIn(LazyEmitLayerT::ModuleSetHandleT H,
-                         const std::string &Name) {
-    return findMangledSymbolIn(H, Mangle(Name));
+  JITSymbol findUnmangledSymbolIn(ModuleHandleT H, const std::string &Name) {
+    return findSymbolIn(H, mangle(Name));
   }
 
-  void addFunctionDefinition(std::unique_ptr<FunctionAST> FnAST) {
-    // Step 1) IRGen a prototype for this function:
+private:
+
+  // This method searches the FunctionDefs map for a definition of 'Name'. If it
+  // finds one it generates a stub for it and returns the address of the stub.
+  RuntimeDyld::SymbolInfo searchFunctionASTs(const std::string &Name) {
+    auto DefI = FunctionDefs.find(Name);
+    if (DefI == FunctionDefs.end())
+      return 0;
+
+    // Return the address of the stub.
+    // Take the FunctionAST out of the map.
+    auto FnAST = std::move(DefI->second);
+    FunctionDefs.erase(DefI);
+
+    // IRGen the AST, add it to the JIT, and return the address for it.
+    auto H = irGenStub(std::move(FnAST));
+    auto Sym = findSymbolIn(H, Name);
+    return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+  }
+
+  // This method will take the AST for a function definition and IR-gen a stub
+  // for that function that will, on first call, IR-gen the actual body of the
+  // function.
+  ModuleHandleT irGenStub(std::unique_ptr<FunctionAST> FnAST) {
+    // Step 1) IRGen a prototype for the stub. This will have the same type as
+    //         the function.
     IRGenContext C(Session);
     Function *F = FnAST->Proto->IRGen(C);
-    C.getM().setDataLayout(TM->getDataLayout());
 
-    // Step 2) Create a compile callback that will be used to compile this
-    //         function when it is first called.
+    // Step 2) Get a compile callback that can be used to compile the body of
+    //         the function. The resulting CallbackInfo type will let us set the
+    //         compile and update actions for the callback, and get a pointer to
+    //         the jit trampoline that we need to call to trigger those actions.
     auto CallbackInfo =
-      CompileCallbacks.getCompileCallback(*F->getFunctionType());
+      CompileCallbacks.getCompileCallback(F->getContext());
 
     // Step 3) Create a stub that will indirectly call the body of this
     //         function once it is compiled. Initially, set the function
-    //         pointer for the indirection to point at the compile callback.
+    //         pointer for the indirection to point at the trampoline.
     std::string BodyPtrName = (F->getName() + "$address").str();
     GlobalVariable *FunctionBodyPointer =
-      createImplPointer(*F, BodyPtrName, CallbackInfo.getAddress());
+      createImplPointer(*F->getType(), *F->getParent(), BodyPtrName,
+                        createIRTypedAddress(*F->getFunctionType(),
+                                             CallbackInfo.getAddress()));
     makeStub(*F, *FunctionBodyPointer);
 
     // Step 4) Add the module containing the stub to the JIT.
-    auto H = addModule(C.takeM());
+    auto StubH = addModule(C.takeM());
 
-    // Step 5) Set the compile and update actions for the callback.
+    // Step 5) Set the compile and update actions.
     //
     //   The compile action will IRGen the function and add it to the JIT, then
     // request its address, which will trigger codegen. Since we don't need the
@@ -1263,32 +1288,38 @@ public:
     //
     //   The update action will update FunctionBodyPointer to point at the newly
     // compiled function.
-    CallbackInfo.setCompileAction(
-      [this,Fn = std::shared_ptr<FunctionAST>(std::move(FnAST))](){
-        auto H = addModule(IRGen(Session, *Fn));
-        return findSymbolIn(H, Fn->Proto->Name).getAddress();
-      });
-    CallbackInfo.setUpdateAction(
-      CompileCallbacks.getLocalFPUpdater(H, Mangle(BodyPtrName)));
+    std::shared_ptr<FunctionAST> Fn = std::move(FnAST);
+    CallbackInfo.setCompileAction([this, Fn, BodyPtrName, StubH]() {
+      auto H = addModule(IRGen(Session, *Fn));
+      auto BodySym = findUnmangledSymbolIn(H, Fn->Proto->Name);
+      auto BodyPtrSym = findUnmangledSymbolIn(StubH, BodyPtrName);
+      assert(BodySym && "Missing function body.");
+      assert(BodyPtrSym && "Missing function pointer.");
+      auto BodyAddr = BodySym.getAddress();
+      auto BodyPtr = reinterpret_cast<void*>(
+                       static_cast<uintptr_t>(BodyPtrSym.getAddress()));
+      memcpy(BodyPtr, &BodyAddr, sizeof(uintptr_t));
+      return BodyAddr;
+    });
+
+    return StubH;
   }
 
-private:
-
-  std::unique_ptr<TargetMachine> TM;
-  Mangler Mang;
   SessionContext &Session;
-
+  SectionMemoryManager CCMgrMemMgr;
   ObjLayerT ObjectLayer;
   CompileLayerT CompileLayer;
   LazyEmitLayerT LazyEmitLayer;
 
+  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs;
+
   JITCompileCallbackManager<LazyEmitLayerT, OrcX86_64> CompileCallbacks;
 };
 
 static void HandleDefinition(SessionContext &S, KaleidoscopeJIT &J) {
   if (auto F = ParseDefinition()) {
     S.addPrototypeAST(llvm::make_unique<PrototypeAST>(*F->Proto));
-    J.addFunctionDefinition(std::move(F));
+    J.addFunctionAST(std::move(F));
   } else {
     // Skip token for error recovery.
     getNextToken();
@@ -1318,7 +1349,7 @@ static void HandleTopLevelExpression(SessionContext &S, KaleidoscopeJIT &J) {
       auto H = J.addModule(C.takeM());
 
       // Get the address of the JIT'd function in memory.
-      auto ExprSymbol = J.findSymbol("__anon_expr");
+      auto ExprSymbol = J.findUnmangledSymbol("__anon_expr");
       
       // Cast it to the right type (takes no arguments, returns a double) so we
       // can call it as a native function.