[MCJIT][Orc] Refactor RTDyldMemoryManager, weave RuntimeDyld::SymbolInfo through
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / CompileOnDemandLayer.h
index b10a275b12e72031c9cd94da0f7ce287b1cae976..7c1398a51c8ef4476072bd2ae1797a224652ad53 100644 (file)
@@ -16,7 +16,7 @@
 #define LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H
 
 #include "IndirectionUtils.h"
-#include "LookasideRTDyldMM.h"
+#include "LambdaResolver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include <list>
@@ -139,12 +139,11 @@ public:
   typedef typename ModuleSetInfoListT::iterator ModuleSetHandleT;
 
   // @brief Fallback lookup functor.
-  typedef std::function<uint64_t(const std::string &)> LookupFtor;
+  typedef std::function<RuntimeDyld::SymbolInfo(const std::string &)> LookupFtor;
 
   /// @brief Construct a compile-on-demand layer instance.
-  CompileOnDemandLayer(BaseLayerT &BaseLayer, LLVMContext &Context)
-    : BaseLayer(BaseLayer),
-      CompileCallbackMgr(BaseLayer, Context, 0, 64) {}
+  CompileOnDemandLayer(BaseLayerT &BaseLayer, CompileCallbackMgrT &CallbackMgr)
+      : BaseLayer(BaseLayer), CompileCallbackMgr(CallbackMgr) {}
 
   /// @brief Add a module to the compile-on-demand layer.
   template <typename ModuleSetT>
@@ -154,9 +153,13 @@ public:
     // If the user didn't supply a fallback lookup then just use
     // getSymbolAddress.
     if (!FallbackLookup)
-      FallbackLookup = [=](const std::string &Name) {
-                         return findSymbol(Name, true).getAddress();
-                       };
+      FallbackLookup =
+        [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+          if (auto Symbol = findSymbol(Name, true))
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return nullptr;
+        };
 
     // Create a lookup context and ModuleSetInfo for this module set.
     // For the purposes of symbol resolution the set Ms will be treated as if
@@ -194,8 +197,8 @@ public:
   ///        below this one.
   JITSymbol findSymbolIn(ModuleSetHandleT H, const std::string &Name,
                          bool ExportedSymbolsOnly) {
-    BaseLayerModuleSetHandleListT &BaseLayerHandles = H->second;
-    for (auto &BH : BaseLayerHandles) {
+
+    for (auto &BH : H->BaseLayerModuleSetHandles) {
       if (auto Symbol = BaseLayer.findSymbolIn(BH, Name, ExportedSymbolsOnly))
         return Symbol;
     }
@@ -256,10 +259,11 @@ private:
         Function *Proto = StubsModule->getFunction(Name);
         assert(Proto && "Failed to clone function decl into stubs module.");
         auto CallbackInfo =
-          CompileCallbackMgr.getCompileCallback(*Proto->getFunctionType());
+          CompileCallbackMgr.getCompileCallback(Proto->getContext());
         GlobalVariable *FunctionBodyPointer =
           createImplPointer(*Proto, Name + AddrSuffix,
-                            CallbackInfo.getAddress());
+                            createIRTypedAddress(*Proto->getFunctionType(),
+                                                 CallbackInfo.getAddress()));
         makeStub(*Proto, *FunctionBodyPointer);
 
         F.setName(Name + BodySuffix);
@@ -294,7 +298,7 @@ private:
                                     M.getDataLayout());
       auto &CCInfo = KVPair.second;
       CCInfo.setUpdateAction(
-        CompileCallbackMgr.getLocalFPUpdater(StubsH, AddrName));
+        getLocalFPUpdater(BaseLayer, StubsH, AddrName));
     }
   }
 
@@ -314,19 +318,25 @@ private:
     MSet.push_back(std::move(M));
 
     auto DylibLookup = MSI.Lookup;
-    auto MM =
-      createLookasideRTDyldMM<SectionMemoryManager>(
+    auto Resolver =
+      createLambdaResolver(
         [=](const std::string &Name) {
           if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
-            return Symbol.getAddress();
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
           return FallbackLookup(Name);
         },
-        [=](const std::string &Name) {
-          return DylibLookup->findSymbol(LogicalModule, Name).getAddress();
+        [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+          if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return nullptr;
         });
 
     BaseLayerModuleSetHandleT H =
-      BaseLayer.addModuleSet(std::move(MSet), std::move(MM));
+      BaseLayer.addModuleSet(std::move(MSet),
+                             make_unique<SectionMemoryManager>(),
+                             std::move(Resolver));
     // Add this module to the logical module lookup.
     DylibLookup->addToLogicalModule(LogicalModule, H);
     MSI.BaseLayerModuleSetHandles.push_back(H);
@@ -345,7 +355,7 @@ private:
   }
 
   BaseLayerT &BaseLayer;
-  CompileCallbackMgrT CompileCallbackMgr;
+  CompileCallbackMgrT &CompileCallbackMgr;
   ModuleSetInfoListT ModuleSetInfos;
 };