Reorganize the C API headers to improve build times.
[oota-llvm.git] / unittests / ExecutionEngine / Orc / OrcCAPITest.cpp
index bddb6a76bbf70bd5a45a27c60d67a341b48ef9df..776d26970a311b7d163a977d5f8a7f5c55febb0b 100644 (file)
@@ -9,6 +9,7 @@
 
 #include "OrcTestCommon.h"
 #include "gtest/gtest.h"
+#include "llvm-c/Core.h"
 #include "llvm-c/OrcBindings.h"
 #include "llvm-c/Target.h"
 #include "llvm-c/TargetMachine.h"
 
 namespace llvm {
 
-DEFINE_SIMPLE_CONVERSION_FUNCTIONS(TargetMachine, LLVMTargetMachineRef);
+DEFINE_SIMPLE_CONVERSION_FUNCTIONS(TargetMachine, LLVMTargetMachineRef)
 
 class OrcCAPIExecutionTest : public testing::Test, public OrcExecutionTest {
 protected:
-
   std::unique_ptr<Module> createTestModule(const Triple &TT) {
     ModuleBuilder MB(getGlobalContext(), TT.str(), "");
     Function *TestFunc = MB.createFunctionDecl<int()>("testFunc");
@@ -37,9 +37,9 @@ protected:
     return MB.takeModule();
   }
 
-  typedef int (*MainFnTy)(void);
+  typedef int (*MainFnTy)();
 
-  static int myTestFuncImpl(void) {
+  static int myTestFuncImpl() {
     return 42;
   }
 
@@ -51,25 +51,44 @@ protected:
     return 0;
   }
 
+  struct CompileContext {
+    CompileContext() : Compiled(false) { }
+
+    OrcCAPIExecutionTest* APIExecTest;
+    std::unique_ptr<Module> M;
+    LLVMOrcModuleHandle H;
+    bool Compiled;
+  };
+
+  static LLVMOrcTargetAddress myCompileCallback(LLVMOrcJITStackRef JITStack,
+                                                void *Ctx) {
+    CompileContext *CCtx = static_cast<CompileContext*>(Ctx);
+    auto *ET = CCtx->APIExecTest;
+    CCtx->M = ET->createTestModule(ET->TM->getTargetTriple());
+    CCtx->H = LLVMOrcAddEagerlyCompiledIR(JITStack, wrap(CCtx->M.get()),
+                                          myResolver, nullptr);
+    CCtx->Compiled = true;
+    LLVMOrcTargetAddress MainAddr = LLVMOrcGetSymbolAddress(JITStack, "main");
+    LLVMOrcSetIndirectStubPointer(JITStack, "foo", MainAddr);
+    return MainAddr;
+  }
 };
 
-char *OrcCAPIExecutionTest::testFuncName = 0;
+char *OrcCAPIExecutionTest::testFuncName = nullptr;
 
 TEST_F(OrcCAPIExecutionTest, TestEagerIRCompilation) {
-  auto TM = getHostTargetMachineIfSupported();
-
   if (!TM)
     return;
 
-  std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
-
   LLVMOrcJITStackRef JIT =
-    LLVMOrcCreateInstance(wrap(TM.get()), LLVMGetGlobalContext());
+    LLVMOrcCreateInstance(wrap(TM.get()));
+
+  std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
 
   LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
 
   LLVMOrcModuleHandle H =
-    LLVMOrcAddEagerlyCompiledIR(JIT, wrap(M.get()), myResolver, 0);
+    LLVMOrcAddEagerlyCompiledIR(JIT, wrap(M.get()), myResolver, nullptr);
   MainFnTy MainFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "main");
   int Result = MainFn();
   EXPECT_EQ(Result, 42)
@@ -82,19 +101,18 @@ TEST_F(OrcCAPIExecutionTest, TestEagerIRCompilation) {
 }
 
 TEST_F(OrcCAPIExecutionTest, TestLazyIRCompilation) {
-  auto TM = getHostTargetMachineIfSupported();
-
   if (!TM)
     return;
 
-  std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
-
   LLVMOrcJITStackRef JIT =
-    LLVMOrcCreateInstance(wrap(TM.get()), LLVMGetGlobalContext());
+    LLVMOrcCreateInstance(wrap(TM.get()));
+
+  std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
 
   LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
+
   LLVMOrcModuleHandle H =
-    LLVMOrcAddLazilyCompiledIR(JIT, wrap(M.get()), myResolver, 0);
+    LLVMOrcAddLazilyCompiledIR(JIT, wrap(M.get()), myResolver, nullptr);
   MainFnTy MainFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "main");
   int Result = MainFn();
   EXPECT_EQ(Result, 42)
@@ -106,4 +124,37 @@ TEST_F(OrcCAPIExecutionTest, TestLazyIRCompilation) {
   LLVMOrcDisposeInstance(JIT);
 }
 
+TEST_F(OrcCAPIExecutionTest, TestDirectCallbacksAPI) {
+  if (!TM)
+    return;
+
+  LLVMOrcJITStackRef JIT =
+    LLVMOrcCreateInstance(wrap(TM.get()));
+
+  LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
+
+  CompileContext C;
+  C.APIExecTest = this;
+  LLVMOrcCreateIndirectStub(JIT, "foo",
+                            LLVMOrcCreateLazyCompileCallback(JIT,
+                                                             myCompileCallback,
+                                                             &C));
+  MainFnTy FooFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "foo");
+  int Result = FooFn();
+  EXPECT_TRUE(C.Compiled)
+    << "Function wasn't lazily compiled";
+  EXPECT_EQ(Result, 42)
+    << "Direct-callback JIT'd code did not return expected result";
+
+  C.Compiled = false;
+  FooFn();
+  EXPECT_FALSE(C.Compiled)
+    << "Direct-callback JIT'd code was JIT'd twice";
+
+  LLVMOrcRemoveModule(JIT, C.H);
+
+  LLVMOrcDisposeMangledSymbol(testFuncName);
+  LLVMOrcDisposeInstance(JIT);
 }
+
+} // namespace llvm