Big Kaleidoscope tutorial update.
authorLang Hames <lhames@gmail.com>
Wed, 26 Aug 2015 03:07:41 +0000 (03:07 +0000)
committerLang Hames <lhames@gmail.com>
Wed, 26 Aug 2015 03:07:41 +0000 (03:07 +0000)
This commit switches the underlying JIT for the Kaleidoscope tutorials from
MCJIT to a custom ORC-based JIT, KaleidoscopeJIT. This fixes a lot of the bugs
in Kaleidoscope that were introduced when we deleted the legacy JIT. The
documentation for Chapter 4, which introduces the JIT APIs, is updated to
reflect the change.

Also included are a number of C++11 modernizations and general cleanup. Where
appropriate, the docs have been updated to reflect these changes too.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@246002 91177308-0d34-0410-b5e6-96231b3b80d8

15 files changed:
docs/tutorial/LangImpl2.rst
docs/tutorial/LangImpl3.rst
docs/tutorial/LangImpl4.rst
docs/tutorial/LangImpl5.rst
docs/tutorial/LangImpl6.rst
docs/tutorial/LangImpl7.rst
docs/tutorial/LangImpl8.rst
examples/Kaleidoscope/Chapter2/toy.cpp
examples/Kaleidoscope/Chapter3/toy.cpp
examples/Kaleidoscope/Chapter4/toy.cpp
examples/Kaleidoscope/Chapter5/toy.cpp
examples/Kaleidoscope/Chapter6/toy.cpp
examples/Kaleidoscope/Chapter7/toy.cpp
examples/Kaleidoscope/Chapter8/toy.cpp
examples/Kaleidoscope/include/KaleidoscopeJIT.h [new file with mode: 0644]

index 09c55e60f3ac6c579d69aafc1d6aa533080ac925..92a266eeb031b838e32d522357b855d0079d7297 100644 (file)
@@ -85,7 +85,7 @@ language:
     /// CallExprAST - Expression class for function calls.
     class CallExprAST : public ExprAST {
       std::string Callee;
-      std::vector<ExprAST*> Args;
+      std::vector<std::unique_ptr<ExprAST>> Args;
 
     public:
       CallExprAST(const std::string &Callee,
index d80140ef241b1b724bff64f34334a5d7e99e749b..49711d581b9410915551b9180e5362b4790c89f5 100644 (file)
@@ -15,8 +15,8 @@ LLVM IR. This will teach you a little bit about how LLVM does things, as
 well as demonstrate how easy it is to use. It's much more work to build
 a lexer and parser than it is to generate LLVM IR code. :)
 
-**Please note**: the code in this chapter and later require LLVM 2.2 or
-later. LLVM 2.1 and before will not work with it. Also note that you
+**Please note**: the code in this chapter and later require LLVM 3.7 or
+later. LLVM 3.6 and before will not work with it. Also note that you
 need to use a version of this tutorial that matches your LLVM release:
 If you are using an official LLVM release, use the version of the
 documentation included with your release or on the `llvm.org releases
@@ -35,7 +35,7 @@ class:
     class ExprAST {
     public:
       virtual ~ExprAST() {}
-      virtual Value *Codegen() = 0;
+      virtual Value *codegen() = 0;
     };
 
     /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -44,11 +44,11 @@ class:
 
     public:
       NumberExprAST(double Val) : Val(Val) {}
-      virtual Value *Codegen();
+      virtual Value *codegen();
     };
     ...
 
-The Codegen() method says to emit IR for that AST node along with all
+The codegen() method says to emit IR for that AST node along with all
 the things it depends on, and they all return an LLVM Value object.
 "Value" is the class used to represent a "`Static Single Assignment
 (SSA) <http://en.wikipedia.org/wiki/Static_single_assignment_form>`_
@@ -73,19 +73,20 @@ parser, which will be used to report errors found during code generation
 
 .. code-block:: c++
 
+    static std::unique_ptr<Module> *TheModule;
+    static IRBuilder<> Builder(getGlobalContext());
+    static std::map<std::string, Value*> NamedValues;
+
     Value *ErrorV(const char *Str) {
       Error(Str);
       return nullptr;
     }
 
-    static Module *TheModule;
-    static IRBuilder<> Builder(getGlobalContext());
-    static std::map<std::string, Value*> NamedValues;
-
 The static variables will be used during code generation. ``TheModule``
-is the LLVM construct that contains all of the functions and global
-variables in a chunk of code. In many ways, it is the top-level
-structure that the LLVM IR uses to contain code.
+is an LLVM construct that contains functions and global variables. In many
+ways, it is the top-level structure that the LLVM IR uses to contain code.
+It will own the memory for all of the IR that we generate, which is why
+the codegen() method returns a raw Value\*, rather than a unique_ptr<Value>.
 
 The ``Builder`` object is a helper object that makes it easy to generate
 LLVM instructions. Instances of the
@@ -114,7 +115,7 @@ First we'll do numeric literals:
 
 .. code-block:: c++
 
-    Value *NumberExprAST::Codegen() {
+    Value *NumberExprAST::codegen() {
       return ConstantFP::get(getGlobalContext(), APFloat(Val));
     }
 
@@ -128,7 +129,7 @@ are all uniqued together and shared. For this reason, the API uses the
 
 .. code-block:: c++
 
-    Value *VariableExprAST::Codegen() {
+    Value *VariableExprAST::codegen() {
       // Look this variable up in the function.
       Value *V = NamedValues[Name];
       if (!V)
@@ -148,9 +149,9 @@ variables <LangImpl7.html#localvars>`_.
 
 .. code-block:: c++
 
-    Value *BinaryExprAST::Codegen() {
-      Value *L = LHS->Codegen();
-      Value *R = RHS->Codegen();
+    Value *BinaryExprAST::codegen() {
+      Value *L = LHS->codegen();
+      Value *R = RHS->codegen();
       if (!L || !R)
         return nullptr;
 
@@ -209,7 +210,7 @@ would return 0.0 and -1.0, depending on the input value.
 
 .. code-block:: c++
 
-    Value *CallExprAST::Codegen() {
+    Value *CallExprAST::codegen() {
       // Look up the name in the global module table.
       Function *CalleeF = TheModule->getFunction(Callee);
       if (!CalleeF)
@@ -221,7 +222,7 @@ would return 0.0 and -1.0, depending on the input value.
 
       std::vector<Value *> ArgsV;
       for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-        ArgsV.push_back(Args[i]->Codegen());
+        ArgsV.push_back(Args[i]->codegen());
         if (!ArgsV.back())
           return nullptr;
       }
@@ -229,12 +230,11 @@ would return 0.0 and -1.0, depending on the input value.
       return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
     }
 
-Code generation for function calls is quite straightforward with LLVM.
-The code above initially does a function name lookup in the LLVM
-Module's symbol table. Recall that the LLVM Module is the container that
-holds all of the functions we are JIT'ing. By giving each function the
-same name as what the user specifies, we can use the LLVM symbol table
-to resolve function names for us.
+Code generation for function calls is quite straightforward with LLVM. The code
+above initially does a function name lookup in the LLVM Module's symbol table.
+Recall that the LLVM Module is the container that holds the functions we are
+JIT'ing. By giving each function the same name as what the user specifies, we
+can use the LLVM symbol table to resolve function names for us.
 
 Once we have the function to call, we recursively codegen each argument
 that is to be passed in, and create an LLVM `call
@@ -261,7 +261,7 @@ with:
 
 .. code-block:: c++
 
-    Function *PrototypeAST::Codegen() {
+    Function *PrototypeAST::codegen() {
       // Make the function type:  double(double,double) etc.
       std::vector<Type*> Doubles(Args.size(),
                                  Type::getDoubleTy(getGlobalContext()));
@@ -286,119 +286,67 @@ double as a result, and that is not vararg (the false parameter
 indicates this). Note that Types in LLVM are uniqued just like Constants
 are, so you don't "new" a type, you "get" it.
 
-The final line above actually creates the function that the prototype
-will correspond to. This indicates the type, linkage and name to use, as
+The final line above actually creates the IR Function corresponding to
+the Prototype. This indicates the type, linkage and name to use, as
 well as which module to insert into. "`external
 linkage <../LangRef.html#linkage>`_" means that the function may be
 defined outside the current module and/or that it is callable by
 functions outside the module. The Name passed in is the name the user
 specified: since "``TheModule``" is specified, this name is registered
-in "``TheModule``"s symbol table, which is used by the function call
-code above.
+in "``TheModule``"s symbol table.
 
 .. code-block:: c++
 
-      // If F conflicted, there was already something named 'Name'.  If it has a
-      // body, don't allow redefinition or reextern.
-      if (F->getName() != Name) {
-        // Delete the one we just made and get the existing one.
-        F->eraseFromParent();
-        F = TheModule->getFunction(Name);
-
-The Module symbol table works just like the Function symbol table when
-it comes to name conflicts: if a new function is created with a name
-that was previously added to the symbol table, the new function will get
-implicitly renamed when added to the Module. The code above exploits
-this fact to determine if there was a previous definition of this
-function.
-
-In Kaleidoscope, I choose to allow redefinitions of functions in two
-cases: first, we want to allow 'extern'ing a function more than once, as
-long as the prototypes for the externs match (since all arguments have
-the same type, we just have to check that the number of arguments
-match). Second, we want to allow 'extern'ing a function and then
-defining a body for it. This is useful when defining mutually recursive
-functions.
-
-In order to implement this, the code above first checks to see if there
-is a collision on the name of the function. If so, it deletes the
-function we just created (by calling ``eraseFromParent``) and then
-calling ``getFunction`` to get the existing function with the specified
-name. Note that many APIs in LLVM have "erase" forms and "remove" forms.
-The "remove" form unlinks the object from its parent (e.g. a Function
-from a Module) and returns it. The "erase" form unlinks the object and
-then deletes it.
+  // Set names for all arguments.
+  unsigned Idx = 0;
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
-.. code-block:: c++
+  return F;
 
-        // If F already has a body, reject this.
-        if (!F->empty()) {
-          ErrorF("redefinition of function");
-          return nullptr;
-        }
+Finally, we set the name of each of the function's arguments according to the
+names given in the Prototype. This step isn't strictly necessary, but keeping
+the names consistent makes the IR more readable, and allows subsequent code to
+refer directly to the arguments for their names, rather than having to look up
+them up in the Prototype AST.
 
-        // If F took a different number of args, reject.
-        if (F->arg_size() != Args.size()) {
-          ErrorF("redefinition of function with different # args");
-          return nullptr;
-        }
-      }
-
-In order to verify the logic above, we first check to see if the
-pre-existing function is "empty". In this case, empty means that it has
-no basic blocks in it, which means it has no body. If it has no body, it
-is a forward declaration. Since we don't allow anything after a full
-definition of the function, the code rejects this case. If the previous
-reference to a function was an 'extern', we simply verify that the
-number of arguments for that definition and this one match up. If not,
-we emit an error.
+At this point we have a function prototype with no body. This is how LLVM IR
+represents function declarations. For extern statements in Kaleidoscope, this
+is as far as we need to go. For function definitions however, we need to
+codegen and attach a function body.
 
 .. code-block:: c++
 
-      // Set names for all arguments.
-      unsigned Idx = 0;
-      for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-           ++AI, ++Idx) {
-        AI->setName(Args[Idx]);
-
-        // Add arguments to variable symbol table.
-        NamedValues[Args[Idx]] = AI;
-      }
-
-      return F;
-    }
+  Function *FunctionAST::codegen() {
+      // First, check for an existing function from a previous 'extern' declaration.
+    Function *TheFunction = TheModule->getFunction(Proto->getName());
 
-The last bit of code for prototypes loops over all of the arguments in
-the function, setting the name of the LLVM Argument objects to match,
-and registering the arguments in the ``NamedValues`` map for future use
-by the ``VariableExprAST`` AST node. Once this is set up, it returns the
-Function object to the caller. Note that we don't check for conflicting
-argument names here (e.g. "extern foo(a b a)"). Doing so would be very
-straight-forward with the mechanics we have already used above.
+    if (!TheFunction)
+      TheFunction = Proto->codegen();
 
-.. code-block:: c++
+    if (!TheFunction)
+      return nullptr;
 
-    Function *FunctionAST::Codegen() {
-      NamedValues.clear();
+    if (!TheFunction->empty())
+      return (Function*)ErrorV("Function cannot be redefined.");
 
-      Function *TheFunction = Proto->Codegen();
-      if (!TheFunction)
-        return nullptr;
 
-Code generation for function definitions starts out simply enough: we
-just codegen the prototype (Proto) and verify that it is ok. We then
-clear out the ``NamedValues`` map to make sure that there isn't anything
-in it from the last function we compiled. Code generation of the
-prototype ensures that there is an LLVM Function object that is ready to
-go for us.
+For function definitions, we start by searching TheModule's symbol table for an
+existing version of this function, in case one has already been created using an
+'extern' statement. If Module::getFunction returns null then no previous version
+exists, so we'll codegen one from the Prototype. In either case, we want to
+assert that the function is empty (i.e. has no body yet) before we start.
 
 .. code-block:: c++
 
-      // Create a new basic block to start insertion into.
-      BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
-      Builder.SetInsertPoint(BB);
+  // Create a new basic block to start insertion into.
+  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
+  Builder.SetInsertPoint(BB);
 
-      if (Value *RetVal = Body->Codegen()) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args())
+    NamedValues[Arg.getName()] = &Arg;
 
 Now we get to the point where the ``Builder`` is set up. The first line
 creates a new `basic block <http://en.wikipedia.org/wiki/Basic_block>`_
@@ -410,9 +358,12 @@ Graph <http://en.wikipedia.org/wiki/Control_flow_graph>`_. Since we
 don't have any control flow, our functions will only contain one block
 at this point. We'll fix this in `Chapter 5 <LangImpl5.html>`_ :).
 
+Next we add the function arguments to the NamedValues map (after first clearing
+it out) so that they're accessible to ``VariableExprAST`` nodes.
+
 .. code-block:: c++
 
-      if (Value *RetVal = Body->Codegen()) {
+      if (Value *RetVal = Body->codegen()) {
         // Finish off the function.
         Builder.CreateRet(RetVal);
 
@@ -422,11 +373,11 @@ at this point. We'll fix this in `Chapter 5 <LangImpl5.html>`_ :).
         return TheFunction;
       }
 
-Once the insertion point is set up, we call the ``CodeGen()`` method for
-the root expression of the function. If no error happens, this emits
-code to compute the expression into the entry block and returns the
-value that was computed. Assuming no error, we then create an LLVM `ret
-instruction <../LangRef.html#i_ret>`_, which completes the function.
+Once the insertion point has been set up and the NamedValues map populated,
+we call the ``codegen()`` method for the root expression of the function. If no
+error happens, this emits code to compute the expression into the entry block
+and returns the value that was computed. Assuming no error, we then create an
+LLVM `ret instruction <../LangRef.html#i_ret>`_, which completes the function.
 Once the function is built, we call ``verifyFunction``, which is
 provided by LLVM. This function does a variety of consistency checks on
 the generated code, to determine if our compiler is doing everything
@@ -446,23 +397,25 @@ we handle this by merely deleting the function we produced with the
 that they incorrectly typed in before: if we didn't delete it, it would
 live in the symbol table, with a body, preventing future redefinition.
 
-This code does have a bug, though. Since the ``PrototypeAST::Codegen``
-can return a previously defined forward declaration, our code can
-actually delete a forward declaration. There are a number of ways to fix
-this bug, see what you can come up with! Here is a testcase:
+This code does have a bug, though: If the ``FunctionAST::codegen()`` method
+finds an existing IR Function, it does not validate its signature against the
+definition's own prototype. This means that an earlier 'extern' declaration will
+take precedence over the function definition's signature, which can cause
+codegen to fail, for instance if the function arguments are named differently.
+There are a number of ways to fix this bug, see what you can come up with! Here
+is a testcase:
 
 ::
 
-    extern foo(a b);     # ok, defines foo.
-    def foo(a b) c;      # error, 'c' is invalid.
-    def bar() foo(1, 2); # error, unknown function "foo"
+    extern foo(a);     # ok, defines foo.
+    def foo(b) b;      # Error: Unknown variable name. (decl using 'a' takes precedence).
 
 Driver Changes and Closing Thoughts
 ===================================
 
 For now, code generation to LLVM doesn't really get us much, except that
 we can look at the pretty IR calls. The sample code inserts calls to
-Codegen into the "``HandleDefinition``", "``HandleExtern``" etc
+codegen into the "``HandleDefinition``", "``HandleExtern``" etc
 functions, and then dumps out the LLVM IR. This gives a nice way to look
 at the LLVM IR for simple functions. For example:
 
index 497a4c56a38cf3b9f96c5c90f5b1f3dc53563e70..702886f6aec6b432f4b51795339307634e99cb4a 100644 (file)
@@ -122,55 +122,51 @@ optimizer until the entire file has been parsed.
 In order to get per-function optimizations going, we need to set up a
 `FunctionPassManager <../WritingAnLLVMPass.html#passmanager>`_ to hold
 and organize the LLVM optimizations that we want to run. Once we have
-that, we can add a set of optimizations to run. The code looks like
-this:
+that, we can add a set of optimizations to run. We'll need a new
+FunctionPassManager for each module that we want to optimize, so we'll
+write a function to create and initialize both the module and pass manager
+for us:
 
 .. code-block:: c++
 
-      FunctionPassManager OurFPM(TheModule);
+    void InitializeModuleAndPassManager(void) {
+      // Open a new module.
+      TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+      TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+
+      // Create a new pass manager attached to it.
+      TheFPM = llvm::make_unique<FunctionPassManager>(TheModule.get());
 
-      // Set up the optimizer pipeline.  Start with registering info about how the
-      // target lays out data structures.
-      OurFPM.add(new DataLayout(*TheExecutionEngine->getDataLayout()));
       // Provide basic AliasAnalysis support for GVN.
-      OurFPM.add(createBasicAliasAnalysisPass());
+      TheFPM.add(createBasicAliasAnalysisPass());
       // Do simple "peephole" optimizations and bit-twiddling optzns.
-      OurFPM.add(createInstructionCombiningPass());
+      TheFPM.add(createInstructionCombiningPass());
       // Reassociate expressions.
-      OurFPM.add(createReassociatePass());
+      TheFPM.add(createReassociatePass());
       // Eliminate Common SubExpressions.
-      OurFPM.add(createGVNPass());
+      TheFPM.add(createGVNPass());
       // Simplify the control flow graph (deleting unreachable blocks, etc).
-      OurFPM.add(createCFGSimplificationPass());
-
-      OurFPM.doInitialization();
-
-      // Set the global so the code gen can use this.
-      TheFPM = &OurFPM;
+      TheFPM.add(createCFGSimplificationPass());
 
-      // Run the main "interpreter loop" now.
-      MainLoop();
+      TheFPM.doInitialization();
+    }
 
-This code defines a ``FunctionPassManager``, "``OurFPM``". It requires a
-pointer to the ``Module`` to construct itself. Once it is set up, we use
-a series of "add" calls to add a bunch of LLVM passes. The first pass is
-basically boilerplate, it adds a pass so that later optimizations know
-how the data structures in the program are laid out. The
-"``TheExecutionEngine``" variable is related to the JIT, which we will
-get to in the next section.
+This code initializes the global module ``TheModule``, and the function pass
+manager ``TheFPM``, which is attached to ``TheModule``. One the pass manager is
+set up, we use a series of "add" calls to add a bunch of LLVM passes.
 
-In this case, we choose to add 4 optimization passes. The passes we
-chose here are a pretty standard set of "cleanup" optimizations that are
-useful for a wide variety of code. I won't delve into what they do but,
-believe me, they are a good starting place :).
+In this case, we choose to add five passes: one analysis pass (alias analysis),
+and four optimization passes. The passes we choose here are a pretty standard set
+of "cleanup" optimizations that are useful for a wide variety of code. I won't
+delve into what they do but, believe me, they are a good starting place :).
 
 Once the PassManager is set up, we need to make use of it. We do this by
 running it after our newly created function is constructed (in
-``FunctionAST::Codegen``), but before it is returned to the client:
+``FunctionAST::codegen()``), but before it is returned to the client:
 
 .. code-block:: c++
 
-      if (Value *RetVal = Body->Codegen()) {
+      if (Value *RetVal = Body->codegen()) {
         // Finish off the function.
         Builder.CreateRet(RetVal);
 
@@ -231,55 +227,85 @@ should evaluate and print out 3. If they define a function, they should
 be able to call it from the command line.
 
 In order to do this, we first declare and initialize the JIT. This is
-done by adding a global variable and a call in ``main``:
+done by adding a global variable ``TheJIT``, and initializing it in
+``main``:
 
 .. code-block:: c++
 
-    static ExecutionEngine *TheExecutionEngine;
+    static std::unique_ptr<KaleidoscopeJIT> TheJIT;
     ...
     int main() {
       ..
-      // Create the JIT.  This takes ownership of the module.
-      TheExecutionEngine = EngineBuilder(TheModule).create();
-      ..
+      TheJIT = llvm::make_unique<KaleidoscopeJIT>();
+
+      // Run the main "interpreter loop" now.
+      MainLoop();
+
+      return 0;
     }
 
-This creates an abstract "Execution Engine" which can be either a JIT
-compiler or the LLVM interpreter. LLVM will automatically pick a JIT
-compiler for you if one is available for your platform, otherwise it
-will fall back to the interpreter.
+The KaleidoscopeJIT class is a simple JIT built specifically for these
+tutorials. In later chapters we will look at how it works and extend it with
+new features, but for now we will take it as given. Its API is very simple::
+``addModule`` adds an LLVM IR module to the JIT, making its functions
+available for execution; ``removeModule`` removes a module, freeing any
+memory associated with the code in that module; and ``findSymbol`` allows us
+to look up pointers to the compiled code.
 
-Once the ``ExecutionEngine`` is created, the JIT is ready to be used.
-There are a variety of APIs that are useful, but the simplest one is the
-"``getPointerToFunction(F)``" method. This method JIT compiles the
-specified LLVM Function and returns a function pointer to the generated
-machine code. In our case, this means that we can change the code that
-parses a top-level expression to look like this:
+We can take this simple API and change our code that parses top-level expressions to
+look like this:
 
 .. code-block:: c++
 
     static void HandleTopLevelExpression() {
       // Evaluate a top-level expression into an anonymous function.
       if (auto FnAST = ParseTopLevelExpr()) {
-        if (auto *FnIR = FnAST->Codegen()) {
-          FnIR->dump();  // Dump the function for exposition purposes.
+        if (FnAST->codegen()) {
+
+          // JIT the module containing the anonymous expression, keeping a handle so
+          // we can free it later.
+          auto H = TheJIT->addModule(std::move(TheModule));
+          InitializeModuleAndPassManager();
 
-          // JIT the function, returning a function pointer.
-          void *FPtr = TheExecutionEngine->getPointerToFunction(FnIR);
+          // Search the JIT for the __anon_expr symbol.
+          auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
+          assert(ExprSymbol && "Function not found");
 
-          // Cast it to the right type (takes no arguments, returns a double) so we
-          // can call it as a native function.
-          double (*FP)() = (double (*)())(intptr_t)FPtr;
+          // Get the symbol's address and cast it to the right type (takes no
+          // arguments, returns a double) so we can call it as a native function.
+          double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
           fprintf(stderr, "Evaluated to %f\n", FP());
+
+          // Delete the anonymous expression module from the JIT.
+          TheJIT->removeModule(H);
         }
 
-Recall that we compile top-level expressions into a self-contained LLVM
-function that takes no arguments and returns the computed double.
-Because the LLVM JIT compiler matches the native platform ABI, this
-means that you can just cast the result pointer to a function pointer of
-that type and call it directly. This means, there is no difference
-between JIT compiled code and native machine code that is statically
-linked into your application.
+If parsing and codegen succeeed, the next step is to add the module containing
+the top-level expression to the JIT. We do this by calling addModule, which
+triggers code generation for all the functions in the module, and returns a
+handle that can be used to remove the module from the JIT later. Once the module
+has been added to the JIT it can no longer be modified, so we also open a new
+module to hold subsequent code by calling ``InitializeModuleAndPassManager()``.
+
+Once we've added the module to the JIT we need to get a pointer to the final
+generated code. We do this by calling the JIT's findSymbol method, and passing
+the name of the top-level expression function: ``__anon_expr``. Since we just
+added this function, we assert that findSymbol returned a result.
+
+Next, we get the in-memory address of the ``__anon_expr`` function by calling
+``getAddress()`` on the symbol. Recall that we compile top-level expressions
+into a self-contained LLVM function that takes no arguments and returns the
+computed double. Because the LLVM JIT compiler matches the native platform ABI,
+this means that you can just cast the result pointer to a function pointer of
+that type and call it directly. This means, there is no difference between JIT
+compiled code and native machine code that is statically linked into your
+application.
+
+Finally, since we don't support re-evaluation of top-level expressions, we
+remove the module from the JIT when we're done to free the associated memory.
+Recall, however, that the module we created a few lines earlier (via
+``InitializeModuleAndPassManager``) is still open and waiting for new code to be
+added.
 
 With just these two changes, lets see how Kaleidoscope works now!
 
@@ -320,19 +346,161 @@ demonstrates very basic functionality, but can we do more?
 
     Evaluated to 24.000000
 
-This illustrates that we can now call user code, but there is something
-a bit subtle going on here. Note that we only invoke the JIT on the
-anonymous functions that *call testfunc*, but we never invoked it on
-*testfunc* itself. What actually happened here is that the JIT scanned
-for all non-JIT'd functions transitively called from the anonymous
-function and compiled all of them before returning from
-``getPointerToFunction()``.
+    ready> testfunc(5, 10);
+    ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!
+
+
+Function definitions and calls also work, but something went very wrong on that
+last line. The call looks valid, so what happened? As you may have guessed from
+the the API a Module is a unit of allocation for the JIT, and testfunc was part
+of the same module that contained anonymous expression. When we removed that
+module from the JIT to free the memory for the anonymous expression, we deleted
+the definition of ``testfunc`` along with it. Then, when we tried to call
+testfunc a second time, the JIT could no longer find it.
+
+The easiest way to fix this is to put the anonymous expression in a separate
+module from the rest of the function definitions. The JIT will happily resolve
+function calls across module boundaries, as long as each of the functions called
+has a prototype, and is added to the JIT before it is called. By putting the
+anonymous expression in a different module we can delete it without affecting
+the rest of the functions.
+
+In fact, we're going to go a step further and put every function in its own
+module. Doing so allows us to exploit a useful property of the KaleidoscopeJIT
+that will make our environment more REPL-like: Functions can be added to the
+JIT more than once (unlike a module where every function must have a unique
+definition). When you look up a symbol in KaleidoscopeJIT it will always return
+the most recent definition:
+
+::
+
+    ready> def foo(x) x + 1;
+    Read function definition:
+    define double @foo(double %x) {
+    entry:
+      %addtmp = fadd double %x, 1.000000e+00
+      ret double %addtmp
+    }
+
+    ready> foo(2);
+    Evaluated to 3.000000
+
+    ready> def foo(x) x + 2;
+    define double @foo(double %x) {
+    entry:
+      %addtmp = fadd double %x, 2.000000e+00
+      ret double %addtmp
+    }
+
+    ready> foo(2);
+    Evaluated to 4.000000
+
+
+To allow each function to live in its own module we'll need a way to
+re-generate previous function declarations into each new module we open:
+
+.. code-block:: c++
+
+    static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+
+    ...
+
+    Function *getFunction(std::string Name) {
+      // First, see if the function has already been added to the current module.
+      if (auto *F = TheModule->getFunction(Name))
+        return F;
+
+      // If not, check whether we can codegen the declaration from some existing
+      // prototype.
+      auto FI = FunctionProtos.find(Name);
+      if (FI != FunctionProtos.end())
+        return FI->second->codegen();
+
+      // If no existing prototype exists, return null.
+      return nullptr;
+    }
+
+    ...
+
+    Value *CallExprAST::codegen() {
+      // Look up the name in the global module table.
+      Function *CalleeF = getFunction(Callee);
+
+    ...
+
+    Function *FunctionAST::codegen() {
+      // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+      // reference to it for use below.
+      auto &P = *Proto;
+      FunctionProtos[Proto->getName()] = std::move(Proto);
+      Function *TheFunction = getFunction(P.getName());
+      if (!TheFunction)
+        return nullptr;
+
+
+To enable this, we'll start by adding a new global, ``FunctionProtos``, that
+holds the most recent prototype for each function. We'll also add a convenience
+method, ``getFunction()``, to replace calls to ``TheModule->getFunction()``.
+Our convenience method searches ``TheModule`` for an existing function
+declaration, falling back to generating a new declaration from FunctionProtos if
+it doesn't find one. In ``CallExprAST::codegen()`` we just need to replace the
+call to ``TheModule->getFunction()``. In ``FunctionAST::codegen()`` we need to
+update the FunctionProtos map first, then call ``getFunction()``. With this
+done, we can always obtain a function declaration in the current module for any
+previously declared function.
+
+We also need to update HandleDefinition and HandleExtern:
+
+.. code-block:: c++
+
+    static void HandleDefinition() {
+      if (auto FnAST = ParseDefinition()) {
+        if (auto *FnIR = FnAST->codegen()) {
+          fprintf(stderr, "Read function definition:");
+          FnIR->dump();
+          TheJIT->addModule(std::move(TheModule));
+          InitializeModuleAndPassManager();
+        }
+      } else {
+        // Skip token for error recovery.
+         getNextToken();
+      }
+    }
+
+    static void HandleExtern() {
+      if (auto ProtoAST = ParseExtern()) {
+        if (auto *FnIR = ProtoAST->codegen()) {
+          fprintf(stderr, "Read extern: ");
+          FnIR->dump();
+          FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
+        }
+      } else {
+        // Skip token for error recovery.
+        getNextToken();
+      }
+    }
+
+In HandleDefinition, we add two lines to transfer the newly defined function to
+the JIT and open a new module. In HandleExtern, we just need to add one line to
+add the prototype to FunctionProtos.
+
+With these changes made, lets try our REPL again (I removed the dump of the
+anonymous functions this time, you should get the idea by now :) :
+
+::
+
+    ready> def foo(x) x + 1;
+    ready> foo(2);
+    Evaluated to 3.000000
+
+    ready> def foo(x) x + 2;
+    ready> foo(2);
+    Evaluated to 4.000000
+
+It works!
 
-The JIT provides a number of other more advanced interfaces for things
-like freeing allocated machine code, rejit'ing functions to update them,
-etc. However, even with this simple code, we get some surprisingly
-powerful capabilities - check this out (I removed the dump of the
-anonymous functions, you should get the idea by now :) :
+Even with this simple code, we get some surprisingly powerful capabilities -
+check this out:
 
 ::
 
@@ -375,27 +543,24 @@ anonymous functions, you should get the idea by now :) :
 
     Evaluated to 1.000000
 
-Whoa, how does the JIT know about sin and cos? The answer is
-surprisingly simple: in this example, the JIT started execution of a
-function and got to a function call. It realized that the function was
-not yet JIT compiled and invoked the standard set of routines to resolve
-the function. In this case, there is no body defined for the function,
-so the JIT ended up calling "``dlsym("sin")``" on the Kaleidoscope
-process itself. Since "``sin``" is defined within the JIT's address
-space, it simply patches up calls in the module to call the libm version
-of ``sin`` directly.
-
-The LLVM JIT provides a number of interfaces (look in the
-``ExecutionEngine.h`` file) for controlling how unknown functions get
-resolved. It allows you to establish explicit mappings between IR
-objects and addresses (useful for LLVM global variables that you want to
-map to static tables, for example), allows you to dynamically decide on
-the fly based on the function name, and even allows you to have the JIT
-compile functions lazily the first time they're called.
-
-One interesting application of this is that we can now extend the
-language by writing arbitrary C++ code to implement operations. For
-example, if we add:
+Whoa, how does the JIT know about sin and cos? The answer is surprisingly
+simple: The KaleidoscopeJIT has a straightforward symbol resolution rule that
+it uses to find symbols that aren't available in any given module: First
+it searches all the modules that have already been added to the JIT, from the
+most recent to the oldest, to find the newest definition. If no definition is
+found inside the JIT, it falls back to calling "``dlsym("sin")``" on the
+Kaleidoscope process itself. Since "``sin``" is defined within the JIT's
+address space, it simply patches up calls in the module to call the libm
+version of ``sin`` directly.
+
+In the future we'll see how tweaking this symbol resolution rule can be used to
+enable all sorts of useful features, from security (restricting the set of
+symbols available to JIT'd code), to dynamic code generation based on symbol
+names, and even lazy compilation.
+
+One immediate benefit of the symbol resolution rule is that we can now extend
+the language by writing arbitrary C++ code to implement operations. For example,
+if we add:
 
 .. code-block:: c++
 
index c0420fa70f7b39e77d09b788cf978a2d792afadf..7b8c29a1977f9092efaa07313e8eb3228492c930 100644 (file)
@@ -103,7 +103,7 @@ To represent the new expression we add a new AST node for it:
       IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
                 std::unique_ptr<ExprAST> Else)
         : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
-      virtual Value *Codegen();
+      virtual Value *codegen();
     };
 
 The AST node just has pointers to the various subexpressions.
@@ -280,13 +280,13 @@ Okay, enough of the motivation and overview, lets generate code!
 Code Generation for If/Then/Else
 --------------------------------
 
-In order to generate code for this, we implement the ``Codegen`` method
+In order to generate code for this, we implement the ``codegen`` method
 for ``IfExprAST``:
 
 .. code-block:: c++
 
-    Value *IfExprAST::Codegen() {
-      Value *CondV = Cond->Codegen();
+    Value *IfExprAST::codegen() {
+      Value *CondV = Cond->codegen();
       if (!CondV)
         return nullptr;
 
@@ -337,7 +337,7 @@ that LLVM supports forward references.
       // Emit then value.
       Builder.SetInsertPoint(ThenBB);
 
-      Value *ThenV = Then->Codegen();
+      Value *ThenV = Then->codegen();
       if (!ThenV)
         return nullptr;
 
@@ -369,7 +369,7 @@ of the block in the CFG. Why then, are we getting the current block when
 we just set it to ThenBB 5 lines above? The problem is that the "Then"
 expression may actually itself change the block that the Builder is
 emitting into if, for example, it contains a nested "if/then/else"
-expression. Because calling Codegen recursively could arbitrarily change
+expression. Because calling ``codegen()`` recursively could arbitrarily change
 the notion of the current block, we are required to get an up-to-date
 value for code that will set up the Phi node.
 
@@ -379,12 +379,12 @@ value for code that will set up the Phi node.
       TheFunction->getBasicBlockList().push_back(ElseBB);
       Builder.SetInsertPoint(ElseBB);
 
-      Value *ElseV = Else->Codegen();
+      Value *ElseV = Else->codegen();
       if (!ElseV)
         return nullptr;
 
       Builder.CreateBr(MergeBB);
-      // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
+      // codegen of 'Else' can change the current block, update ElseBB for the PHI.
       ElseBB = Builder.GetInsertBlock();
 
 Code generation for the 'else' block is basically identical to codegen
@@ -500,7 +500,7 @@ variable name and the constituent expressions in the node.
                  std::unique_ptr<ExprAST> Body)
         : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
           Step(std::move(Step)), Body(std::move(Body)) {}
-      virtual Value *Codegen();
+      virtual Value *codegen();
     };
 
 Parser Extensions for the 'for' Loop
@@ -602,14 +602,14 @@ together.
 Code Generation for the 'for' Loop
 ----------------------------------
 
-The first part of Codegen is very simple: we just output the start
+The first part of codegen is very simple: we just output the start
 expression for the loop value:
 
 .. code-block:: c++
 
-    Value *ForExprAST::Codegen() {
+    Value *ForExprAST::codegen() {
       // Emit the start code first, without 'variable' in scope.
-      Value *StartVal = Start->Codegen();
+      Value *StartVal = Start->codegen();
       if (StartVal == 0) return 0;
 
 With this out of the way, the next step is to set up the LLVM basic
@@ -663,7 +663,7 @@ backedge, but we can't set it up yet (because it doesn't exist!).
       // Emit the body of the loop.  This, like any other expr, can change the
       // current BB.  Note that we ignore the value computed by the body, but don't
       // allow an error.
-      if (!Body->Codegen())
+      if (!Body->codegen())
         return nullptr;
 
 Now the code starts to get more interesting. Our 'for' loop introduces a
@@ -688,7 +688,7 @@ table.
       // Emit the step value.
       Value *StepVal = nullptr;
       if (Step) {
-        StepVal = Step->Codegen();
+        StepVal = Step->codegen();
         if (!StepVal)
           return nullptr;
       } else {
@@ -706,7 +706,7 @@ iteration of the loop.
 .. code-block:: c++
 
       // Compute the end condition.
-      Value *EndCond = End->Codegen();
+      Value *EndCond = End->codegen();
       if (!EndCond)
         return nullptr;
 
@@ -759,7 +759,7 @@ value, we can add the incoming value to the loop PHI node. After that,
 we remove the loop variable from the symbol table, so that it isn't in
 scope after the for loop. Finally, code generation of the for loop
 always returns 0.0, so that is what we return from
-``ForExprAST::Codegen``.
+``ForExprAST::codegen()``.
 
 With this, we conclude the "adding control flow to Kaleidoscope" chapter
 of the tutorial. In this chapter we added two control flow constructs,
index 4918cb08edf28ec08748b8ebf71d80d7a2bbff00..d2884ba61c28fbcf484ec5c0d12f8bf9ced730a4 100644 (file)
@@ -153,7 +153,7 @@ this:
 
       unsigned getBinaryPrecedence() const { return Precedence; }
 
-      Function *Codegen();
+      Function *codegen();
     };
 
 Basically, in addition to knowing a name for the prototype, we now keep
@@ -235,9 +235,9 @@ default case for our existing binary operator node:
 
 .. code-block:: c++
 
-    Value *BinaryExprAST::Codegen() {
-      Value *L = LHS->Codegen();
-      Value *R = RHS->Codegen();
+    Value *BinaryExprAST::codegen() {
+      Value *L = LHS->codegen();
+      Value *R = RHS->codegen();
       if (!L || !R)
         return nullptr;
 
@@ -276,10 +276,10 @@ The final piece of code we are missing, is a bit of top-level magic:
 
 .. code-block:: c++
 
-    Function *FunctionAST::Codegen() {
+    Function *FunctionAST::codegen() {
       NamedValues.clear();
 
-      Function *TheFunction = Proto->Codegen();
+      Function *TheFunction = Proto->codegen();
       if (!TheFunction)
         return nullptr;
 
@@ -291,7 +291,7 @@ The final piece of code we are missing, is a bit of top-level magic:
       BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
       Builder.SetInsertPoint(BB);
 
-      if (Value *RetVal = Body->Codegen()) {
+      if (Value *RetVal = Body->codegen()) {
         ...
 
 Basically, before codegening a function, if it is a user-defined
@@ -323,7 +323,7 @@ that, we need an AST node:
     public:
       UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
         : Opcode(Opcode), Operand(std::move(Operand)) {}
-      virtual Value *Codegen();
+      virtual Value *codegen();
     };
 
 This AST node is very simple and obvious by now. It directly mirrors the
@@ -428,8 +428,8 @@ unary operators. It looks like this:
 
 .. code-block:: c++
 
-    Value *UnaryExprAST::Codegen() {
-      Value *OperandV = Operand->Codegen();
+    Value *UnaryExprAST::codegen() {
+      Value *OperandV = Operand->codegen();
       if (!OperandV)
         return nullptr;
 
index 8c35f2ac0191fdbf4b94027c1697e67b7f1e6fb6..3ab27deebe2d3e4bdc0d1f58d3b517cec1f615e2 100644 (file)
@@ -355,7 +355,7 @@ from the stack slot:
 
 .. code-block:: c++
 
-    Value *VariableExprAST::Codegen() {
+    Value *VariableExprAST::codegen() {
       // Look this variable up in the function.
       Value *V = NamedValues[Name];
       if (!V)
@@ -367,7 +367,7 @@ from the stack slot:
 
 As you can see, this is pretty straightforward. Now we need to update
 the things that define the variables to set up the alloca. We'll start
-with ``ForExprAST::Codegen`` (see the `full code listing <#code>`_ for
+with ``ForExprAST::codegen()`` (see the `full code listing <#code>`_ for
 the unabridged code):
 
 .. code-block:: c++
@@ -378,7 +378,7 @@ the unabridged code):
       AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
 
         // Emit the start code first, without 'variable' in scope.
-      Value *StartVal = Start->Codegen();
+      Value *StartVal = Start->codegen();
       if (!StartVal)
         return nullptr;
 
@@ -387,7 +387,7 @@ the unabridged code):
       ...
 
       // Compute the end condition.
-      Value *EndCond = End->Codegen();
+      Value *EndCond = End->codegen();
       if (!EndCond)
         return nullptr;
 
@@ -426,7 +426,7 @@ them. The code for this is also pretty simple:
 
 For each argument, we make an alloca, store the input value to the
 function into the alloca, and register the alloca as the memory location
-for the argument. This method gets invoked by ``FunctionAST::Codegen``
+for the argument. This method gets invoked by ``FunctionAST::codegen()``
 right after it sets up the entry block for the function.
 
 The final missing piece is adding the mem2reg pass, which allows us to
@@ -572,7 +572,7 @@ implement codegen for the assignment operator. This looks like:
 
 .. code-block:: c++
 
-    Value *BinaryExprAST::Codegen() {
+    Value *BinaryExprAST::codegen() {
       // Special case '=' because we don't want to emit the LHS as an expression.
       if (Op == '=') {
         // Assignment requires the LHS to be an identifier.
@@ -590,7 +590,7 @@ allowed.
 .. code-block:: c++
 
         // Codegen the RHS.
-        Value *Val = RHS->Codegen();
+        Value *Val = RHS->codegen();
         if (!Val)
           return nullptr;
 
@@ -680,7 +680,7 @@ var/in, it looks like this:
                  std::unique_ptr<ExprAST> body)
       : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
 
-      virtual Value *Codegen();
+      virtual Value *codegen();
     };
 
 var/in allows a list of names to be defined all at once, and each name
@@ -785,7 +785,7 @@ emission of LLVM IR for it. This code starts out with:
 
 .. code-block:: c++
 
-    Value *VarExprAST::Codegen() {
+    Value *VarExprAST::codegen() {
       std::vector<AllocaInst *> OldBindings;
 
       Function *TheFunction = Builder.GetInsertBlock()->getParent();
@@ -808,7 +808,7 @@ previous value that we replace in OldBindings.
         //    var a = a in ...   # refers to outer 'a'.
         Value *InitVal;
         if (Init) {
-          InitVal = Init->Codegen();
+          InitVal = Init->codegen();
           if (!InitVal)
             return nullptr;
         } else { // If not specified, use 0.0.
@@ -834,7 +834,7 @@ we evaluate the body of the var/in expression:
 .. code-block:: c++
 
       // Codegen the body, now that all vars are in scope.
-      Value *BodyVal = Body->Codegen();
+      Value *BodyVal = Body->codegen();
       if (!BodyVal)
         return nullptr;
 
index 77e3e429674eefc42aa0fa45ddb3a784508a737c..dff6ddcf27051c2e73074cd300e7f784b97b97b3 100644 (file)
@@ -109,7 +109,7 @@ code is that the llvm IR goes to standard error:
    static void HandleTopLevelExpression() {
      // Evaluate a top-level expression into an anonymous function.
      if (auto FnAST = ParseTopLevelExpr()) {
-  -    if (auto *FnIR = FnAST->Codegen()) {
+  -    if (auto *FnIR = FnAST->codegen()) {
   -      // We're just doing this to make sure it executes.
   -      TheExecutionEngine->finalizeObject();
   -      // JIT the function, returning a function pointer.
@@ -120,7 +120,7 @@ code is that the llvm IR goes to standard error:
   -      double (*FP)() = (double (*)())(intptr_t)FPtr;
   -      // Ignore the return value for this.
   -      (void)FP;
-  +    if (!F->Codegen()) {
+  +    if (!F->codegen()) {
   +      fprintf(stderr, "Error generating code for top level expr");
        }
      } else {
@@ -237,7 +237,7 @@ Functions
 =========
 
 Now that we have our ``Compile Unit`` and our source locations, we can add
-function definitions to the debug info. So in ``PrototypeAST::Codegen`` we
+function definitions to the debug info. So in ``PrototypeAST::codegen()`` we
 add a few lines of code to describe a context for our subprogram, in this
 case the "File", and the actual definition of the function itself.
 
@@ -309,7 +309,7 @@ and then we have added to all of our AST classes a source location:
      public:
        ExprAST(SourceLocation Loc = CurLoc) : Loc(Loc) {}
        virtual ~ExprAST() {}
-       virtual Value* Codegen() = 0;
+       virtual Value* codegen() = 0;
        int getLine() const { return Loc.Line; }
        int getCol() const { return Loc.Col; }
        virtual raw_ostream &dump(raw_ostream &out, int ind) {
index 9ecf75d346706f71d3469eaaec1a3b6b0800344f..14cba32f5f91308584baa2a236efda4af9b67a27 100644 (file)
@@ -348,8 +348,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
index cb4f75f7917debd5dd7ee00c94a49f00163740bf..328189c118465e66480ee303e74fa8d16454eeca 100644 (file)
@@ -1,9 +1,8 @@
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Verifier.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include <cctype>
 #include <cstdio>
 #include <map>
@@ -91,7 +90,7 @@ namespace {
 class ExprAST {
 public:
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
 };
 
 /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -100,7 +99,7 @@ class NumberExprAST : public ExprAST {
 
 public:
   NumberExprAST(double Val) : Val(Val) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -109,7 +108,7 @@ class VariableExprAST : public ExprAST {
 
 public:
   VariableExprAST(const std::string &Name) : Name(Name) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -121,7 +120,7 @@ public:
   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -133,7 +132,7 @@ public:
   CallExprAST(const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : Callee(Callee), Args(std::move(Args)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -146,7 +145,8 @@ class PrototypeAST {
 public:
   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
       : Name(Name), Args(std::move(Args)) {}
-  Function *Codegen();
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -158,7 +158,7 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-  Function *Codegen();
+  Function *codegen();
 };
 } // end anonymous namespace
 
@@ -197,10 +197,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -365,8 +361,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
@@ -382,20 +378,20 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
 // Code Generation
 //===----------------------------------------------------------------------===//
 
+static std::unique_ptr<Module> TheModule;
+static IRBuilder<> Builder(getGlobalContext());
+static std::map<std::string, Value *> NamedValues;
+
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
-static Module *TheModule;
-static IRBuilder<> Builder(getGlobalContext());
-static std::map<std::string, Value *> NamedValues;
-
-Value *NumberExprAST::Codegen() {
+Value *NumberExprAST::codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -403,9 +399,9 @@ Value *VariableExprAST::Codegen() {
   return V;
 }
 
-Value *BinaryExprAST::Codegen() {
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+Value *BinaryExprAST::codegen() {
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -426,7 +422,7 @@ Value *BinaryExprAST::Codegen() {
   }
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
   Function *CalleeF = TheModule->getFunction(Callee);
   if (!CalleeF)
@@ -438,7 +434,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -446,7 +442,7 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
@@ -454,45 +450,23 @@ Function *PrototypeAST::Codegen() {
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
   Function *F =
-      Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx) {
-    AI->setName(Args[Idx]);
-
-    // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = AI;
-  }
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
   return F;
 }
 
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
+Function *FunctionAST::codegen() {
+  // First, check for an existing function from a previous 'extern' declaration.
+  Function *TheFunction = TheModule->getFunction(Proto->getName());
+
+  if (!TheFunction)
+    TheFunction = Proto->codegen();
 
-  Function *TheFunction = Proto->Codegen();
   if (!TheFunction)
     return nullptr;
 
@@ -500,7 +474,12 @@ Function *FunctionAST::Codegen() {
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
 
-  if (Value *RetVal = Body->Codegen()) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args())
+    NamedValues[Arg.getName()] = &Arg;
+
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
@@ -521,7 +500,7 @@ Function *FunctionAST::Codegen() {
 
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
     }
@@ -533,7 +512,7 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (auto *FnIR = ProtoAST->Codegen()) {
+    if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
     }
@@ -546,7 +525,7 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read top-level expression:");
       FnIR->dump();
     }
@@ -584,8 +563,6 @@ static void MainLoop() {
 //===----------------------------------------------------------------------===//
 
 int main() {
-  LLVMContext &Context = getGlobalContext();
-
   // Install standard binary operators.
   // 1 is lowest precedence.
   BinopPrecedence['<'] = 10;
@@ -598,9 +575,7 @@ int main() {
   getNextToken();
 
   // Make the module, which holds all the code.
-  std::unique_ptr<Module> Owner =
-      llvm::make_unique<Module>("my cool jit", Context);
-  TheModule = Owner.get();
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
 
   // Run the main "interpreter loop" now.
   MainLoop();
index 9b0d2ecf67c602ae29dcedad186f436d0f4ee33e..12777ae2c758fa93587b969b6c546aaedd9c9423 100644 (file)
@@ -1,11 +1,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/MCJIT.h"
-#include "llvm/ExecutionEngine/SectionMemoryManager.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include <map>
 #include <string>
 #include <vector>
+#include "../include/KaleidoscopeJIT.h"
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -100,7 +98,7 @@ namespace {
 class ExprAST {
 public:
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
 };
 
 /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -109,7 +107,7 @@ class NumberExprAST : public ExprAST {
 
 public:
   NumberExprAST(double Val) : Val(Val) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -118,7 +116,7 @@ class VariableExprAST : public ExprAST {
 
 public:
   VariableExprAST(const std::string &Name) : Name(Name) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -130,7 +128,7 @@ public:
   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -142,7 +140,7 @@ public:
   CallExprAST(const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : Callee(Callee), Args(std::move(Args)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -155,7 +153,8 @@ class PrototypeAST {
 public:
   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
       : Name(Name), Args(std::move(Args)) {}
-  Function *Codegen();
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -167,7 +166,7 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-  Function *Codegen();
+  Function *codegen();
 };
 } // end anonymous namespace
 
@@ -206,10 +205,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -374,8 +369,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
@@ -387,259 +382,42 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
   return ParsePrototype();
 }
 
-//===----------------------------------------------------------------------===//
-// Quick and dirty hack
-//===----------------------------------------------------------------------===//
-
-// FIXME: Obviously we can do better than this
-std::string GenerateUniqueName(const char *root) {
-  static int i = 0;
-  char s[16];
-  sprintf(s, "%s%d", root, i++);
-  std::string S = s;
-  return S;
-}
-
-std::string MakeLegalFunctionName(std::string Name) {
-  std::string NewName;
-  if (!Name.length())
-    return GenerateUniqueName("anon_func_");
-
-  // Start with what we have
-  NewName = Name;
-
-  // Look for a numberic first character
-  if (NewName.find_first_of("0123456789") == 0) {
-    NewName.insert(0, 1, 'n');
-  }
-
-  // Replace illegal characters with their ASCII equivalent
-  std::string legal_elements =
-      "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
-  size_t pos;
-  while ((pos = NewName.find_first_not_of(legal_elements)) !=
-         std::string::npos) {
-    char old_c = NewName.at(pos);
-    char new_str[16];
-    sprintf(new_str, "%d", (int)old_c);
-    NewName = NewName.replace(pos, 1, new_str);
-  }
-
-  return NewName;
-}
-
-//===----------------------------------------------------------------------===//
-// MCJIT helper class
-//===----------------------------------------------------------------------===//
-
-class MCJITHelper {
-public:
-  MCJITHelper(LLVMContext &C) : Context(C), OpenModule(NULL) {}
-  ~MCJITHelper();
-
-  Function *getFunction(const std::string FnName);
-  Module *getModuleForNewFunction();
-  void *getPointerToFunction(Function *F);
-  void *getSymbolAddress(const std::string &Name);
-  void dump();
-
-private:
-  typedef std::vector<Module *> ModuleVector;
-  typedef std::vector<ExecutionEngine *> EngineVector;
-
-  LLVMContext &Context;
-  Module *OpenModule;
-  ModuleVector Modules;
-  EngineVector Engines;
-};
-
-class HelpingMemoryManager : public SectionMemoryManager {
-  HelpingMemoryManager(const HelpingMemoryManager &) = delete;
-  void operator=(const HelpingMemoryManager &) = delete;
-
-public:
-  HelpingMemoryManager(MCJITHelper *Helper) : MasterHelper(Helper) {}
-  ~HelpingMemoryManager() override {}
-
-  /// This method returns the address of the specified symbol.
-  /// Our implementation will attempt to find symbols in other
-  /// modules associated with the MCJITHelper to cross link symbols
-  /// from one generated module to another.
-  uint64_t getSymbolAddress(const std::string &Name) override;
-
-private:
-  MCJITHelper *MasterHelper;
-};
-
-uint64_t HelpingMemoryManager::getSymbolAddress(const std::string &Name) {
-  uint64_t FnAddr = SectionMemoryManager::getSymbolAddress(Name);
-  if (FnAddr)
-    return FnAddr;
-
-  uint64_t HelperFun = (uint64_t)MasterHelper->getSymbolAddress(Name);
-  if (!HelperFun)
-    report_fatal_error("Program used extern function '" + Name +
-                       "' which could not be resolved!");
-
-  return HelperFun;
-}
-
-MCJITHelper::~MCJITHelper() {
-  if (OpenModule)
-    delete OpenModule;
-  EngineVector::iterator begin = Engines.begin();
-  EngineVector::iterator end = Engines.end();
-  EngineVector::iterator it;
-  for (it = begin; it != end; ++it)
-    delete *it;
-}
-
-Function *MCJITHelper::getFunction(const std::string FnName) {
-  ModuleVector::iterator begin = Modules.begin();
-  ModuleVector::iterator end = Modules.end();
-  ModuleVector::iterator it;
-  for (it = begin; it != end; ++it) {
-    Function *F = (*it)->getFunction(FnName);
-    if (F) {
-      if (*it == OpenModule)
-        return F;
-
-      assert(OpenModule != NULL);
-
-      // This function is in a module that has already been JITed.
-      // We need to generate a new prototype for external linkage.
-      Function *PF = OpenModule->getFunction(FnName);
-      if (PF && !PF->empty()) {
-        ErrorF("redefinition of function across modules");
-        return nullptr;
-      }
-
-      // If we don't have a prototype yet, create one.
-      if (!PF)
-        PF = Function::Create(F->getFunctionType(), Function::ExternalLinkage,
-                              FnName, OpenModule);
-      return PF;
-    }
-  }
-  return NULL;
-}
-
-Module *MCJITHelper::getModuleForNewFunction() {
-  // If we have a Module that hasn't been JITed, use that.
-  if (OpenModule)
-    return OpenModule;
-
-  // Otherwise create a new Module.
-  std::string ModName = GenerateUniqueName("mcjit_module_");
-  Module *M = new Module(ModName, Context);
-  Modules.push_back(M);
-  OpenModule = M;
-  return M;
-}
-
-void *MCJITHelper::getPointerToFunction(Function *F) {
-  // See if an existing instance of MCJIT has this function.
-  EngineVector::iterator begin = Engines.begin();
-  EngineVector::iterator end = Engines.end();
-  EngineVector::iterator it;
-  for (it = begin; it != end; ++it) {
-    void *P = (*it)->getPointerToFunction(F);
-    if (P)
-      return P;
-  }
-
-  // If we didn't find the function, see if we can generate it.
-  if (OpenModule) {
-    std::string ErrStr;
-    ExecutionEngine *NewEngine =
-        EngineBuilder(std::unique_ptr<Module>(OpenModule))
-            .setErrorStr(&ErrStr)
-            .setMCJITMemoryManager(std::unique_ptr<HelpingMemoryManager>(
-                new HelpingMemoryManager(this)))
-            .create();
-    if (!NewEngine) {
-      fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-      exit(1);
-    }
-
-    // Create a function pass manager for this engine
-    auto *FPM = new legacy::FunctionPassManager(OpenModule);
-
-    // Set up the optimizer pipeline.  Start with registering info about how the
-    // target lays out data structures.
-    OpenModule->setDataLayout(NewEngine->getDataLayout());
-    // Provide basic AliasAnalysis support for GVN.
-    FPM->add(createBasicAliasAnalysisPass());
-    // Promote allocas to registers.
-    FPM->add(createPromoteMemoryToRegisterPass());
-    // Do simple "peephole" optimizations and bit-twiddling optzns.
-    FPM->add(createInstructionCombiningPass());
-    // Reassociate expressions.
-    FPM->add(createReassociatePass());
-    // Eliminate Common SubExpressions.
-    FPM->add(createGVNPass());
-    // Simplify the control flow graph (deleting unreachable blocks, etc).
-    FPM->add(createCFGSimplificationPass());
-    FPM->doInitialization();
-
-    // For each function in the module
-    Module::iterator it;
-    Module::iterator end = OpenModule->end();
-    for (it = OpenModule->begin(); it != end; ++it) {
-      // Run the FPM on this function
-      FPM->run(*it);
-    }
-
-    // We don't need this anymore
-    delete FPM;
-
-    OpenModule = NULL;
-    Engines.push_back(NewEngine);
-    NewEngine->finalizeObject();
-    return NewEngine->getPointerToFunction(F);
-  }
-  return NULL;
-}
-
-void *MCJITHelper::getSymbolAddress(const std::string &Name) {
-  // Look for the symbol in each of our execution engines.
-  EngineVector::iterator begin = Engines.begin();
-  EngineVector::iterator end = Engines.end();
-  EngineVector::iterator it;
-  for (it = begin; it != end; ++it) {
-    uint64_t FAddr = (*it)->getFunctionAddress(Name);
-    if (FAddr) {
-      return (void *)FAddr;
-    }
-  }
-  return NULL;
-}
-
-void MCJITHelper::dump() {
-  ModuleVector::iterator begin = Modules.begin();
-  ModuleVector::iterator end = Modules.end();
-  ModuleVector::iterator it;
-  for (it = begin; it != end; ++it)
-    (*it)->dump();
-}
 //===----------------------------------------------------------------------===//
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static MCJITHelper *JITHelper;
+static std::unique_ptr<Module> TheModule;
 static IRBuilder<> Builder(getGlobalContext());
 static std::map<std::string, Value *> NamedValues;
+static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
+static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
 
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
-Value *NumberExprAST::Codegen() {
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = FunctionProtos.find(Name);
+  if (FI != FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
+Value *NumberExprAST::codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -647,9 +425,9 @@ Value *VariableExprAST::Codegen() {
   return V;
 }
 
-Value *BinaryExprAST::Codegen() {
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+Value *BinaryExprAST::codegen() {
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -670,9 +448,9 @@ Value *BinaryExprAST::Codegen() {
   }
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = JITHelper->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
     return ErrorV("Unknown function referenced");
 
@@ -682,7 +460,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -690,55 +468,30 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
   FunctionType *FT =
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
-  std::string FnName = MakeLegalFunctionName(Name);
-
-  Module *M = JITHelper->getModuleForNewFunction();
-
-  Function *F = Function::Create(FT, Function::ExternalLinkage, FnName, M);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != FnName) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = JITHelper->getFunction(Name);
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+  Function *F =
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx) {
-    AI->setName(Args[Idx]);
-
-    // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = AI;
-  }
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
   return F;
 }
 
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
-
-  Function *TheFunction = Proto->Codegen();
+Function *FunctionAST::codegen() {
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
   if (!TheFunction)
     return nullptr;
 
@@ -746,13 +499,21 @@ Function *FunctionAST::Codegen() {
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
 
-  if (Value *RetVal = Body->Codegen()) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args())
+    NamedValues[Arg.getName()] = &Arg;
+
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
+    // Run the optimizer on the function.
+    TheFPM->run(*TheFunction);
+
     return TheFunction;
   }
 
@@ -765,11 +526,35 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
+static void InitializeModuleAndPassManager() {
+  // Open a new module.
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+
+  // Create a new pass manager attached to it.
+  TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
+
+  // Provide basic AliasAnalysis support for GVN.
+  TheFPM->add(createBasicAliasAnalysisPass());
+  // Do simple "peephole" optimizations and bit-twiddling optzns.
+  TheFPM->add(createInstructionCombiningPass());
+  // Reassociate expressions.
+  TheFPM->add(createReassociatePass());
+  // Eliminate Common SubExpressions.
+  TheFPM->add(createGVNPass());
+  // Simplify the control flow graph (deleting unreachable blocks, etc).
+  TheFPM->add(createCFGSimplificationPass());
+
+  TheFPM->doInitialization();
+}
+
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
+      TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
     }
   } else {
     // Skip token for error recovery.
@@ -779,9 +564,10 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (auto *FnIR = ProtoAST->Codegen()) {
+    if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
+      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
     }
   } else {
     // Skip token for error recovery.
@@ -792,14 +578,24 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->Codegen()) {
-      // JIT the function, returning a function pointer.
-      void *FPtr = JITHelper->getPointerToFunction(FnIR);
+    if (FnAST->codegen()) {
+
+      // JIT the module containing the anonymous expression, keeping a handle so
+      // we can free it later.
+      auto H = TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
 
-      // Cast it to the right type (takes no arguments, returns a double) so we
-      // can call it as a native function.
-      double (*FP)() = (double (*)())(intptr_t)FPtr;
+      // Search the JIT for the __anon_expr symbol.
+      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
+      assert(ExprSymbol && "Function not found");
+
+      // Get the symbol's address and cast it to the right type (takes no
+      // arguments, returns a double) so we can call it as a native function.
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
       fprintf(stderr, "Evaluated to %f\n", FP());
+
+      // Delete the anonymous expression module from the JIT.
+      TheJIT->removeModule(H);
     }
   } else {
     // Skip token for error recovery.
@@ -854,8 +650,6 @@ int main() {
   InitializeNativeTarget();
   InitializeNativeTargetAsmPrinter();
   InitializeNativeTargetAsmParser();
-  LLVMContext &Context = getGlobalContext();
-  JITHelper = new MCJITHelper(Context);
 
   // Install standard binary operators.
   // 1 is lowest precedence.
@@ -868,11 +662,12 @@ int main() {
   fprintf(stderr, "ready> ");
   getNextToken();
 
+  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
+
+  InitializeModuleAndPassManager();
+
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  // Print out all of the generated code.
-  JITHelper->dump();
-
   return 0;
 }
index da7a81c6776f833e6708784c1428bc74b39cdcc9..83af1776b20ebcf80307eee32e3d50ef8b30de4a 100644 (file)
@@ -1,11 +1,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/MCJIT.h"
-#include "llvm/ExecutionEngine/SectionMemoryManager.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include <map>
 #include <string>
 #include <vector>
+#include "../include/KaleidoscopeJIT.h"
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -117,7 +115,7 @@ namespace {
 class ExprAST {
 public:
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
 };
 
 /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -126,7 +124,7 @@ class NumberExprAST : public ExprAST {
 
 public:
   NumberExprAST(double Val) : Val(Val) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -135,7 +133,7 @@ class VariableExprAST : public ExprAST {
 
 public:
   VariableExprAST(const std::string &Name) : Name(Name) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -147,7 +145,7 @@ public:
   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -159,7 +157,7 @@ public:
   CallExprAST(const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : Callee(Callee), Args(std::move(Args)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// IfExprAST - Expression class for if/then/else.
@@ -170,7 +168,7 @@ public:
   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
             std::unique_ptr<ExprAST> Else)
       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// ForExprAST - Expression class for for/in.
@@ -184,7 +182,7 @@ public:
              std::unique_ptr<ExprAST> Body)
       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
         Step(std::move(Step)), Body(std::move(Body)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -197,7 +195,8 @@ class PrototypeAST {
 public:
   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
       : Name(Name), Args(std::move(Args)) {}
-  Function *Codegen();
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -209,7 +208,7 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-  Function *Codegen();
+  Function *codegen();
 };
 } // end anonymous namespace
 
@@ -248,10 +247,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -498,8 +493,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
@@ -515,21 +510,38 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static Module *TheModule;
+static std::unique_ptr<Module> TheModule;
 static IRBuilder<> Builder(getGlobalContext());
 static std::map<std::string, Value *> NamedValues;
-static legacy::FunctionPassManager *TheFPM;
+static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
+static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
 
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
-Value *NumberExprAST::Codegen() {
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = FunctionProtos.find(Name);
+  if (FI != FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
+Value *NumberExprAST::codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -537,9 +549,9 @@ Value *VariableExprAST::Codegen() {
   return V;
 }
 
-Value *BinaryExprAST::Codegen() {
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+Value *BinaryExprAST::codegen() {
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -560,9 +572,9 @@ Value *BinaryExprAST::Codegen() {
   }
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = TheModule->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
     return ErrorV("Unknown function referenced");
 
@@ -572,7 +584,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -580,8 +592,8 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Value *IfExprAST::Codegen() {
-  Value *CondV = Cond->Codegen();
+Value *IfExprAST::codegen() {
+  Value *CondV = Cond->codegen();
   if (!CondV)
     return nullptr;
 
@@ -603,7 +615,7 @@ Value *IfExprAST::Codegen() {
   // Emit then value.
   Builder.SetInsertPoint(ThenBB);
 
-  Value *ThenV = Then->Codegen();
+  Value *ThenV = Then->codegen();
   if (!ThenV)
     return nullptr;
 
@@ -615,7 +627,7 @@ Value *IfExprAST::Codegen() {
   TheFunction->getBasicBlockList().push_back(ElseBB);
   Builder.SetInsertPoint(ElseBB);
 
-  Value *ElseV = Else->Codegen();
+  Value *ElseV = Else->codegen();
   if (!ElseV)
     return nullptr;
 
@@ -649,9 +661,9 @@ Value *IfExprAST::Codegen() {
 //   endcond = endexpr
 //   br endcond, loop, endloop
 // outloop:
-Value *ForExprAST::Codegen() {
+Value *ForExprAST::codegen() {
   // Emit the start code first, without 'variable' in scope.
-  Value *StartVal = Start->Codegen();
+  Value *StartVal = Start->codegen();
   if (!StartVal)
     return nullptr;
 
@@ -681,13 +693,13 @@ Value *ForExprAST::Codegen() {
   // Emit the body of the loop.  This, like any other expr, can change the
   // current BB.  Note that we ignore the value computed by the body, but don't
   // allow an error.
-  if (!Body->Codegen())
+  if (!Body->codegen())
     return nullptr;
 
   // Emit the step value.
   Value *StepVal = nullptr;
   if (Step) {
-    StepVal = Step->Codegen();
+    StepVal = Step->codegen();
     if (!StepVal)
       return nullptr;
   } else {
@@ -698,7 +710,7 @@ Value *ForExprAST::Codegen() {
   Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
 
   // Compute the end condition.
-  Value *EndCond = End->Codegen();
+  Value *EndCond = End->codegen();
   if (!EndCond)
     return nullptr;
 
@@ -730,7 +742,7 @@ Value *ForExprAST::Codegen() {
   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
@@ -738,45 +750,22 @@ Function *PrototypeAST::Codegen() {
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
   Function *F =
-      Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx) {
-    AI->setName(Args[Idx]);
-
-    // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = AI;
-  }
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
   return F;
 }
 
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
-
-  Function *TheFunction = Proto->Codegen();
+Function *FunctionAST::codegen() {
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
   if (!TheFunction)
     return nullptr;
 
@@ -784,14 +773,19 @@ Function *FunctionAST::Codegen() {
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
 
-  if (Value *RetVal = Body->Codegen()) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args())
+    NamedValues[Arg.getName()] = &Arg;
+
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
-    // Optimize the function.
+    // Run the optimizer on the function.
     TheFPM->run(*TheFunction);
 
     return TheFunction;
@@ -806,13 +800,35 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
-static ExecutionEngine *TheExecutionEngine;
+static void InitializeModuleAndPassManager() {
+  // Open a new module.
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+
+  // Create a new pass manager attached to it.
+  TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
+
+  // Provide basic AliasAnalysis support for GVN.
+  TheFPM->add(createBasicAliasAnalysisPass());
+  // Do simple "peephole" optimizations and bit-twiddling optzns.
+  TheFPM->add(createInstructionCombiningPass());
+  // Reassociate expressions.
+  TheFPM->add(createReassociatePass());
+  // Eliminate Common SubExpressions.
+  TheFPM->add(createGVNPass());
+  // Simplify the control flow graph (deleting unreachable blocks, etc).
+  TheFPM->add(createCFGSimplificationPass());
+
+  TheFPM->doInitialization();
+}
 
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
+      TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
     }
   } else {
     // Skip token for error recovery.
@@ -822,9 +838,10 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (auto *FnIR = ProtoAST->Codegen()) {
+    if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
+      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
     }
   } else {
     // Skip token for error recovery.
@@ -835,15 +852,24 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->Codegen()) {
-      TheExecutionEngine->finalizeObject();
-      // JIT the function, returning a function pointer.
-      void *FPtr = TheExecutionEngine->getPointerToFunction(FnIR);
-
-      // Cast it to the right type (takes no arguments, returns a double) so we
-      // can call it as a native function.
-      double (*FP)() = (double (*)())(intptr_t)FPtr;
+    if (FnAST->codegen()) {
+
+      // JIT the module containing the anonymous expression, keeping a handle so
+      // we can free it later.
+      auto H = TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
+
+      // Search the JIT for the __anon_expr symbol.
+      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
+      assert(ExprSymbol && "Function not found");
+
+      // Get the symbol's address and cast it to the right type (takes no
+      // arguments, returns a double) so we can call it as a native function.
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
       fprintf(stderr, "Evaluated to %f\n", FP());
+
+      // Delete the anonymous expression module from the JIT.
+      TheJIT->removeModule(H);
     }
   } else {
     // Skip token for error recovery.
@@ -898,7 +924,6 @@ int main() {
   InitializeNativeTarget();
   InitializeNativeTargetAsmPrinter();
   InitializeNativeTargetAsmParser();
-  LLVMContext &Context = getGlobalContext();
 
   // Install standard binary operators.
   // 1 is lowest precedence.
@@ -911,50 +936,12 @@ int main() {
   fprintf(stderr, "ready> ");
   getNextToken();
 
-  // Make the module, which holds all the code.
-  std::unique_ptr<Module> Owner = make_unique<Module>("my cool jit", Context);
-  TheModule = Owner.get();
-
-  // Create the JIT.  This takes ownership of the module.
-  std::string ErrStr;
-  TheExecutionEngine =
-      EngineBuilder(std::move(Owner))
-          .setErrorStr(&ErrStr)
-          .setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>())
-          .create();
-  if (!TheExecutionEngine) {
-    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-    exit(1);
-  }
+  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
 
-  legacy::FunctionPassManager OurFPM(TheModule);
-
-  // Set up the optimizer pipeline.  Start with registering info about how the
-  // target lays out data structures.
-  TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
-  // Provide basic AliasAnalysis support for GVN.
-  OurFPM.add(createBasicAliasAnalysisPass());
-  // Do simple "peephole" optimizations and bit-twiddling optzns.
-  OurFPM.add(createInstructionCombiningPass());
-  // Reassociate expressions.
-  OurFPM.add(createReassociatePass());
-  // Eliminate Common SubExpressions.
-  OurFPM.add(createGVNPass());
-  // Simplify the control flow graph (deleting unreachable blocks, etc).
-  OurFPM.add(createCFGSimplificationPass());
-
-  OurFPM.doInitialization();
-
-  // Set the global so the code gen can use this.
-  TheFPM = &OurFPM;
+  InitializeModuleAndPassManager();
 
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  TheFPM = 0;
-
-  // Print out all of the generated code.
-  TheModule->dump();
-
   return 0;
 }
index b4e8397e9e6a01d4286352c29cfb780ae55ac0b2..e1bed45189e0593d36c113036b200757080b7f9f 100644 (file)
@@ -1,11 +1,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/MCJIT.h"
-#include "llvm/ExecutionEngine/SectionMemoryManager.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include <map>
 #include <string>
 #include <vector>
+#include "../include/KaleidoscopeJIT.h"
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -125,7 +123,7 @@ namespace {
 class ExprAST {
 public:
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
 };
 
 /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -134,7 +132,7 @@ class NumberExprAST : public ExprAST {
 
 public:
   NumberExprAST(double Val) : Val(Val) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -143,7 +141,7 @@ class VariableExprAST : public ExprAST {
 
 public:
   VariableExprAST(const std::string &Name) : Name(Name) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// UnaryExprAST - Expression class for a unary operator.
@@ -154,7 +152,7 @@ class UnaryExprAST : public ExprAST {
 public:
   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
       : Opcode(Opcode), Operand(std::move(Operand)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -166,7 +164,7 @@ public:
   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -178,7 +176,7 @@ public:
   CallExprAST(const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : Callee(Callee), Args(std::move(Args)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// IfExprAST - Expression class for if/then/else.
@@ -189,7 +187,7 @@ public:
   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
             std::unique_ptr<ExprAST> Else)
       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// ForExprAST - Expression class for for/in.
@@ -203,7 +201,7 @@ public:
              std::unique_ptr<ExprAST> Body)
       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
         Step(std::move(Step)), Body(std::move(Body)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -220,6 +218,8 @@ public:
                bool IsOperator = false, unsigned Prec = 0)
       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
         Precedence(Prec) {}
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 
   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
@@ -230,8 +230,6 @@ public:
   }
 
   unsigned getBinaryPrecedence() const { return Precedence; }
-
-  Function *Codegen();
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -243,7 +241,7 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-  Function *Codegen();
+  Function *codegen();
 };
 } // end anonymous namespace
 
@@ -282,10 +280,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -590,8 +584,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
@@ -607,21 +601,38 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static Module *TheModule;
+static std::unique_ptr<Module> TheModule;
 static IRBuilder<> Builder(getGlobalContext());
 static std::map<std::string, Value *> NamedValues;
-static legacy::FunctionPassManager *TheFPM;
+static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
+static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
 
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
-Value *NumberExprAST::Codegen() {
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = FunctionProtos.find(Name);
+  if (FI != FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
+Value *NumberExprAST::codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -629,21 +640,21 @@ Value *VariableExprAST::Codegen() {
   return V;
 }
 
-Value *UnaryExprAST::Codegen() {
-  Value *OperandV = Operand->Codegen();
+Value *UnaryExprAST::codegen() {
+  Value *OperandV = Operand->codegen();
   if (!OperandV)
     return nullptr;
 
-  Function *F = TheModule->getFunction(std::string("unary") + Opcode);
+  Function *F = getFunction(std::string("unary") + Opcode);
   if (!F)
     return ErrorV("Unknown unary operator");
 
   return Builder.CreateCall(F, OperandV, "unop");
 }
 
-Value *BinaryExprAST::Codegen() {
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+Value *BinaryExprAST::codegen() {
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -665,16 +676,16 @@ Value *BinaryExprAST::Codegen() {
 
   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
   // a call to it.
-  Function *F = TheModule->getFunction(std::string("binary") + Op);
+  Function *F = getFunction(std::string("binary") + Op);
   assert(F && "binary operator not found!");
 
   Value *Ops[] = {L, R};
   return Builder.CreateCall(F, Ops, "binop");
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = TheModule->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
     return ErrorV("Unknown function referenced");
 
@@ -684,7 +695,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -692,8 +703,8 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Value *IfExprAST::Codegen() {
-  Value *CondV = Cond->Codegen();
+Value *IfExprAST::codegen() {
+  Value *CondV = Cond->codegen();
   if (!CondV)
     return nullptr;
 
@@ -715,7 +726,7 @@ Value *IfExprAST::Codegen() {
   // Emit then value.
   Builder.SetInsertPoint(ThenBB);
 
-  Value *ThenV = Then->Codegen();
+  Value *ThenV = Then->codegen();
   if (!ThenV)
     return nullptr;
 
@@ -727,7 +738,7 @@ Value *IfExprAST::Codegen() {
   TheFunction->getBasicBlockList().push_back(ElseBB);
   Builder.SetInsertPoint(ElseBB);
 
-  Value *ElseV = Else->Codegen();
+  Value *ElseV = Else->codegen();
   if (!ElseV)
     return nullptr;
 
@@ -761,9 +772,9 @@ Value *IfExprAST::Codegen() {
 //   endcond = endexpr
 //   br endcond, loop, endloop
 // outloop:
-Value *ForExprAST::Codegen() {
+Value *ForExprAST::codegen() {
   // Emit the start code first, without 'variable' in scope.
-  Value *StartVal = Start->Codegen();
+  Value *StartVal = Start->codegen();
   if (!StartVal)
     return nullptr;
 
@@ -793,13 +804,13 @@ Value *ForExprAST::Codegen() {
   // Emit the body of the loop.  This, like any other expr, can change the
   // current BB.  Note that we ignore the value computed by the body, but don't
   // allow an error.
-  if (!Body->Codegen())
+  if (!Body->codegen())
     return nullptr;
 
   // Emit the step value.
   Value *StepVal = nullptr;
   if (Step) {
-    StepVal = Step->Codegen();
+    StepVal = Step->codegen();
     if (!StepVal)
       return nullptr;
   } else {
@@ -810,7 +821,7 @@ Value *ForExprAST::Codegen() {
   Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
 
   // Compute the end condition.
-  Value *EndCond = End->Codegen();
+  Value *EndCond = End->codegen();
   if (!EndCond)
     return nullptr;
 
@@ -842,7 +853,7 @@ Value *ForExprAST::Codegen() {
   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
@@ -850,64 +861,46 @@ Function *PrototypeAST::Codegen() {
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
   Function *F =
-      Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx) {
-    AI->setName(Args[Idx]);
-
-    // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = AI;
-  }
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
   return F;
 }
 
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
-
-  Function *TheFunction = Proto->Codegen();
+Function *FunctionAST::codegen() {
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
   if (!TheFunction)
     return nullptr;
 
   // If this is an operator, install it.
-  if (Proto->isBinaryOp())
-    BinopPrecedence[Proto->getOperatorName()] = Proto->getBinaryPrecedence();
+  if (P.isBinaryOp())
+    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
 
   // Create a new basic block to start insertion into.
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
 
-  if (Value *RetVal = Body->Codegen()) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args())
+    NamedValues[Arg.getName()] = &Arg;
+
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
-    // Optimize the function.
+    // Run the optimizer on the function.
     TheFPM->run(*TheFunction);
 
     return TheFunction;
@@ -916,7 +909,7 @@ Function *FunctionAST::Codegen() {
   // Error reading body, remove function.
   TheFunction->eraseFromParent();
 
-  if (Proto->isBinaryOp())
+  if (P.isBinaryOp())
     BinopPrecedence.erase(Proto->getOperatorName());
   return nullptr;
 }
@@ -925,13 +918,35 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
-static ExecutionEngine *TheExecutionEngine;
+static void InitializeModuleAndPassManager() {
+  // Open a new module.
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+
+  // Create a new pass manager attached to it.
+  TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
+
+  // Provide basic AliasAnalysis support for GVN.
+  TheFPM->add(createBasicAliasAnalysisPass());
+  // Do simple "peephole" optimizations and bit-twiddling optzns.
+  TheFPM->add(createInstructionCombiningPass());
+  // Reassociate expressions.
+  TheFPM->add(createReassociatePass());
+  // Eliminate Common SubExpressions.
+  TheFPM->add(createGVNPass());
+  // Simplify the control flow graph (deleting unreachable blocks, etc).
+  TheFPM->add(createCFGSimplificationPass());
+
+  TheFPM->doInitialization();
+}
 
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
+      TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
     }
   } else {
     // Skip token for error recovery.
@@ -941,9 +956,10 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (auto *FnIR = ProtoAST->Codegen()) {
+    if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
+      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
     }
   } else {
     // Skip token for error recovery.
@@ -954,15 +970,24 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->Codegen()) {
-      TheExecutionEngine->finalizeObject();
-      // JIT the function, returning a function pointer.
-      void *FPtr = TheExecutionEngine->getPointerToFunction(FnIR);
-
-      // Cast it to the right type (takes no arguments, returns a double) so we
-      // can call it as a native function.
-      double (*FP)() = (double (*)())(intptr_t)FPtr;
+    if (FnAST->codegen()) {
+
+      // JIT the module containing the anonymous expression, keeping a handle so
+      // we can free it later.
+      auto H = TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
+
+      // Search the JIT for the __anon_expr symbol.
+      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
+      assert(ExprSymbol && "Function not found");
+
+      // Get the symbol's address and cast it to the right type (takes no
+      // arguments, returns a double) so we can call it as a native function.
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
       fprintf(stderr, "Evaluated to %f\n", FP());
+
+      // Delete the anonymous expression module from the JIT.
+      TheJIT->removeModule(H);
     }
   } else {
     // Skip token for error recovery.
@@ -1017,7 +1042,6 @@ int main() {
   InitializeNativeTarget();
   InitializeNativeTargetAsmPrinter();
   InitializeNativeTargetAsmParser();
-  LLVMContext &Context = getGlobalContext();
 
   // Install standard binary operators.
   // 1 is lowest precedence.
@@ -1030,50 +1054,12 @@ int main() {
   fprintf(stderr, "ready> ");
   getNextToken();
 
-  // Make the module, which holds all the code.
-  std::unique_ptr<Module> Owner = make_unique<Module>("my cool jit", Context);
-  TheModule = Owner.get();
-
-  // Create the JIT.  This takes ownership of the module.
-  std::string ErrStr;
-  TheExecutionEngine =
-      EngineBuilder(std::move(Owner))
-          .setErrorStr(&ErrStr)
-          .setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>())
-          .create();
-  if (!TheExecutionEngine) {
-    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-    exit(1);
-  }
+  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
 
-  legacy::FunctionPassManager OurFPM(TheModule);
-
-  // Set up the optimizer pipeline.  Start with registering info about how the
-  // target lays out data structures.
-  TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
-  // Provide basic AliasAnalysis support for GVN.
-  OurFPM.add(createBasicAliasAnalysisPass());
-  // Do simple "peephole" optimizations and bit-twiddling optzns.
-  OurFPM.add(createInstructionCombiningPass());
-  // Reassociate expressions.
-  OurFPM.add(createReassociatePass());
-  // Eliminate Common SubExpressions.
-  OurFPM.add(createGVNPass());
-  // Simplify the control flow graph (deleting unreachable blocks, etc).
-  OurFPM.add(createCFGSimplificationPass());
-
-  OurFPM.doInitialization();
-
-  // Set the global so the code gen can use this.
-  TheFPM = &OurFPM;
+  InitializeModuleAndPassManager();
 
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  TheFPM = 0;
-
-  // Print out all of the generated code.
-  TheModule->dump();
-
   return 0;
 }
index a900c5f7b16f3039e6b74b0e6da8867f5f6bdcca..4558522952ce26d405745f2c8c1f3e426d349ccd 100644 (file)
@@ -1,11 +1,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/MCJIT.h"
-#include "llvm/ExecutionEngine/SectionMemoryManager.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include <map>
 #include <string>
 #include <vector>
+#include "../include/KaleidoscopeJIT.h"
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -130,7 +128,7 @@ namespace {
 class ExprAST {
 public:
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
 };
 
 /// NumberExprAST - Expression class for numeric literals like "1.0".
@@ -139,7 +137,7 @@ class NumberExprAST : public ExprAST {
 
 public:
   NumberExprAST(double Val) : Val(Val) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -149,7 +147,7 @@ class VariableExprAST : public ExprAST {
 public:
   VariableExprAST(const std::string &Name) : Name(Name) {}
   const std::string &getName() const { return Name; }
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// UnaryExprAST - Expression class for a unary operator.
@@ -160,7 +158,7 @@ class UnaryExprAST : public ExprAST {
 public:
   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
       : Opcode(Opcode), Operand(std::move(Operand)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -172,7 +170,7 @@ public:
   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -184,7 +182,7 @@ public:
   CallExprAST(const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : Callee(Callee), Args(std::move(Args)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// IfExprAST - Expression class for if/then/else.
@@ -195,7 +193,7 @@ public:
   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
             std::unique_ptr<ExprAST> Else)
       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// ForExprAST - Expression class for for/in.
@@ -209,7 +207,7 @@ public:
              std::unique_ptr<ExprAST> Body)
       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
         Step(std::move(Step)), Body(std::move(Body)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VarExprAST - Expression class for var/in
@@ -222,7 +220,7 @@ public:
       std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
       std::unique_ptr<ExprAST> Body)
       : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -239,6 +237,8 @@ public:
                bool IsOperator = false, unsigned Prec = 0)
       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
         Precedence(Prec) {}
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 
   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
@@ -249,10 +249,6 @@ public:
   }
 
   unsigned getBinaryPrecedence() const { return Precedence; }
-
-  Function *Codegen();
-
-  void CreateArgumentAllocas(Function *F);
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -264,7 +260,7 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-  Function *Codegen();
+  Function *codegen();
 };
 } // end anonymous namespace
 
@@ -303,10 +299,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -662,8 +654,8 @@ static std::unique_ptr<FunctionAST> ParseDefinition() {
 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto =
-        llvm::make_unique<PrototypeAST>("", std::vector<std::string>());
+    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
+                                                 std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
   return nullptr;
@@ -679,16 +671,33 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static Module *TheModule;
+static std::unique_ptr<Module> TheModule;
 static IRBuilder<> Builder(getGlobalContext());
 static std::map<std::string, AllocaInst *> NamedValues;
-static legacy::FunctionPassManager *TheFPM;
+static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
+static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
 
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = FunctionProtos.find(Name);
+  if (FI != FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
 /// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
 /// the function.  This is used for mutable variables etc.
 static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
@@ -699,11 +708,11 @@ static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
                            VarName.c_str());
 }
 
-Value *NumberExprAST::Codegen() {
+Value *NumberExprAST::codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -713,19 +722,19 @@ Value *VariableExprAST::Codegen() {
   return Builder.CreateLoad(V, Name.c_str());
 }
 
-Value *UnaryExprAST::Codegen() {
-  Value *OperandV = Operand->Codegen();
+Value *UnaryExprAST::codegen() {
+  Value *OperandV = Operand->codegen();
   if (!OperandV)
     return nullptr;
 
-  Function *F = TheModule->getFunction(std::string("unary") + Opcode);
+  Function *F = getFunction(std::string("unary") + Opcode);
   if (!F)
     return ErrorV("Unknown unary operator");
 
   return Builder.CreateCall(F, OperandV, "unop");
 }
 
-Value *BinaryExprAST::Codegen() {
+Value *BinaryExprAST::codegen() {
   // Special case '=' because we don't want to emit the LHS as an expression.
   if (Op == '=') {
     // Assignment requires the LHS to be an identifier.
@@ -736,7 +745,7 @@ Value *BinaryExprAST::Codegen() {
     if (!LHSE)
       return ErrorV("destination of '=' must be a variable");
     // Codegen the RHS.
-    Value *Val = RHS->Codegen();
+    Value *Val = RHS->codegen();
     if (!Val)
       return nullptr;
 
@@ -749,8 +758,8 @@ Value *BinaryExprAST::Codegen() {
     return Val;
   }
 
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -772,16 +781,16 @@ Value *BinaryExprAST::Codegen() {
 
   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
   // a call to it.
-  Function *F = TheModule->getFunction(std::string("binary") + Op);
+  Function *F = getFunction(std::string("binary") + Op);
   assert(F && "binary operator not found!");
 
   Value *Ops[] = {L, R};
   return Builder.CreateCall(F, Ops, "binop");
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = TheModule->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
     return ErrorV("Unknown function referenced");
 
@@ -791,7 +800,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -799,8 +808,8 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Value *IfExprAST::Codegen() {
-  Value *CondV = Cond->Codegen();
+Value *IfExprAST::codegen() {
+  Value *CondV = Cond->codegen();
   if (!CondV)
     return nullptr;
 
@@ -822,7 +831,7 @@ Value *IfExprAST::Codegen() {
   // Emit then value.
   Builder.SetInsertPoint(ThenBB);
 
-  Value *ThenV = Then->Codegen();
+  Value *ThenV = Then->codegen();
   if (!ThenV)
     return nullptr;
 
@@ -834,7 +843,7 @@ Value *IfExprAST::Codegen() {
   TheFunction->getBasicBlockList().push_back(ElseBB);
   Builder.SetInsertPoint(ElseBB);
 
-  Value *ElseV = Else->Codegen();
+  Value *ElseV = Else->codegen();
   if (!ElseV)
     return nullptr;
 
@@ -872,14 +881,14 @@ Value *IfExprAST::Codegen() {
 //   store nextvar -> var
 //   br endcond, loop, endloop
 // outloop:
-Value *ForExprAST::Codegen() {
+Value *ForExprAST::codegen() {
   Function *TheFunction = Builder.GetInsertBlock()->getParent();
 
   // Create an alloca for the variable in the entry block.
   AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
 
   // Emit the start code first, without 'variable' in scope.
-  Value *StartVal = Start->Codegen();
+  Value *StartVal = Start->codegen();
   if (!StartVal)
     return nullptr;
 
@@ -905,13 +914,13 @@ Value *ForExprAST::Codegen() {
   // Emit the body of the loop.  This, like any other expr, can change the
   // current BB.  Note that we ignore the value computed by the body, but don't
   // allow an error.
-  if (!Body->Codegen())
+  if (!Body->codegen())
     return nullptr;
 
   // Emit the step value.
   Value *StepVal = nullptr;
   if (Step) {
-    StepVal = Step->Codegen();
+    StepVal = Step->codegen();
     if (!StepVal)
       return nullptr;
   } else {
@@ -920,7 +929,7 @@ Value *ForExprAST::Codegen() {
   }
 
   // Compute the end condition.
-  Value *EndCond = End->Codegen();
+  Value *EndCond = End->codegen();
   if (!EndCond)
     return nullptr;
 
@@ -954,7 +963,7 @@ Value *ForExprAST::Codegen() {
   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
 }
 
-Value *VarExprAST::Codegen() {
+Value *VarExprAST::codegen() {
   std::vector<AllocaInst *> OldBindings;
 
   Function *TheFunction = Builder.GetInsertBlock()->getParent();
@@ -971,7 +980,7 @@ Value *VarExprAST::Codegen() {
     //    var a = a in ...   # refers to outer 'a'.
     Value *InitVal;
     if (Init) {
-      InitVal = Init->Codegen();
+      InitVal = Init->codegen();
       if (!InitVal)
         return nullptr;
     } else { // If not specified, use 0.0.
@@ -990,7 +999,7 @@ Value *VarExprAST::Codegen() {
   }
 
   // Codegen the body, now that all vars are in scope.
-  Value *BodyVal = Body->Codegen();
+  Value *BodyVal = Body->codegen();
   if (!BodyVal)
     return nullptr;
 
@@ -1002,7 +1011,7 @@ Value *VarExprAST::Codegen() {
   return BodyVal;
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
@@ -1010,79 +1019,54 @@ Function *PrototypeAST::Codegen() {
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
   Function *F =
-      Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx)
-    AI->setName(Args[Idx]);
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
 
   return F;
 }
 
-/// CreateArgumentAllocas - Create an alloca for each argument and register the
-/// argument in the symbol table so that references to it will succeed.
-void PrototypeAST::CreateArgumentAllocas(Function *F) {
-  Function::arg_iterator AI = F->arg_begin();
-  for (unsigned Idx = 0, e = Args.size(); Idx != e; ++Idx, ++AI) {
-    // Create an alloca for this variable.
-    AllocaInst *Alloca = CreateEntryBlockAlloca(F, Args[Idx]);
-
-    // Store the initial value into the alloca.
-    Builder.CreateStore(AI, Alloca);
-
-    // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = Alloca;
-  }
-}
-
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
-
-  Function *TheFunction = Proto->Codegen();
+Function *FunctionAST::codegen() {
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
   if (!TheFunction)
     return nullptr;
 
   // If this is an operator, install it.
-  if (Proto->isBinaryOp())
-    BinopPrecedence[Proto->getOperatorName()] = Proto->getBinaryPrecedence();
+  if (P.isBinaryOp())
+    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
 
   // Create a new basic block to start insertion into.
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
 
-  // Add all arguments to the symbol table and create their allocas.
-  Proto->CreateArgumentAllocas(TheFunction);
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  for (auto &Arg : TheFunction->args()) {
+    // Create an alloca for this variable.
+    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
 
-  if (Value *RetVal = Body->Codegen()) {
+    // Store the initial value into the alloca.
+    Builder.CreateStore(&Arg, Alloca);
+
+    // Add arguments to variable symbol table.
+    NamedValues[Arg.getName()] = Alloca;
+  }
+
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
-    // Optimize the function.
+    // Run the optimizer on the function.
     TheFPM->run(*TheFunction);
 
     return TheFunction;
@@ -1091,7 +1075,7 @@ Function *FunctionAST::Codegen() {
   // Error reading body, remove function.
   TheFunction->eraseFromParent();
 
-  if (Proto->isBinaryOp())
+  if (P.isBinaryOp())
     BinopPrecedence.erase(Proto->getOperatorName());
   return nullptr;
 }
@@ -1100,13 +1084,35 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
-static ExecutionEngine *TheExecutionEngine;
+static void InitializeModuleAndPassManager() {
+  // Open a new module.
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+
+  // Create a new pass manager attached to it.
+  TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
+
+  // Provide basic AliasAnalysis support for GVN.
+  TheFPM->add(createBasicAliasAnalysisPass());
+  // Do simple "peephole" optimizations and bit-twiddling optzns.
+  TheFPM->add(createInstructionCombiningPass());
+  // Reassociate expressions.
+  TheFPM->add(createReassociatePass());
+  // Eliminate Common SubExpressions.
+  TheFPM->add(createGVNPass());
+  // Simplify the control flow graph (deleting unreachable blocks, etc).
+  TheFPM->add(createCFGSimplificationPass());
+
+  TheFPM->doInitialization();
+}
 
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (auto *FnIR = FnAST->Codegen()) {
+    if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
+      TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
     }
   } else {
     // Skip token for error recovery.
@@ -1116,9 +1122,10 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (auto *FnIR = ProtoAST->Codegen()) {
+    if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
+      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
     }
   } else {
     // Skip token for error recovery.
@@ -1129,15 +1136,24 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->Codegen()) {
-      TheExecutionEngine->finalizeObject();
-      // JIT the function, returning a function pointer.
-      void *FPtr = TheExecutionEngine->getPointerToFunction(FnIR);
-
-      // Cast it to the right type (takes no arguments, returns a double) so we
-      // can call it as a native function.
-      double (*FP)() = (double (*)())(intptr_t)FPtr;
+    if (FnAST->codegen()) {
+
+      // JIT the module containing the anonymous expression, keeping a handle so
+      // we can free it later.
+      auto H = TheJIT->addModule(std::move(TheModule));
+      InitializeModuleAndPassManager();
+
+      // Search the JIT for the __anon_expr symbol.
+      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
+      assert(ExprSymbol && "Function not found");
+
+      // Get the symbol's address and cast it to the right type (takes no
+      // arguments, returns a double) so we can call it as a native function.
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
       fprintf(stderr, "Evaluated to %f\n", FP());
+
+      // Delete the anonymous expression module from the JIT.
+      TheJIT->removeModule(H);
     }
   } else {
     // Skip token for error recovery.
@@ -1192,7 +1208,6 @@ int main() {
   InitializeNativeTarget();
   InitializeNativeTargetAsmPrinter();
   InitializeNativeTargetAsmParser();
-  LLVMContext &Context = getGlobalContext();
 
   // Install standard binary operators.
   // 1 is lowest precedence.
@@ -1206,52 +1221,12 @@ int main() {
   fprintf(stderr, "ready> ");
   getNextToken();
 
-  // Make the module, which holds all the code.
-  std::unique_ptr<Module> Owner = make_unique<Module>("my cool jit", Context);
-  TheModule = Owner.get();
-
-  // Create the JIT.  This takes ownership of the module.
-  std::string ErrStr;
-  TheExecutionEngine =
-      EngineBuilder(std::move(Owner))
-          .setErrorStr(&ErrStr)
-          .setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>())
-          .create();
-  if (!TheExecutionEngine) {
-    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-    exit(1);
-  }
-
-  legacy::FunctionPassManager OurFPM(TheModule);
-
-  // Set up the optimizer pipeline.  Start with registering info about how the
-  // target lays out data structures.
-  TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
-  // Provide basic AliasAnalysis support for GVN.
-  OurFPM.add(createBasicAliasAnalysisPass());
-  // Promote allocas to registers.
-  OurFPM.add(createPromoteMemoryToRegisterPass());
-  // Do simple "peephole" optimizations and bit-twiddling optzns.
-  OurFPM.add(createInstructionCombiningPass());
-  // Reassociate expressions.
-  OurFPM.add(createReassociatePass());
-  // Eliminate Common SubExpressions.
-  OurFPM.add(createGVNPass());
-  // Simplify the control flow graph (deleting unreachable blocks, etc).
-  OurFPM.add(createCFGSimplificationPass());
-
-  OurFPM.doInitialization();
+  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
 
-  // Set the global so the code gen can use this.
-  TheFPM = &OurFPM;
+  InitializeModuleAndPassManager();
 
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  TheFPM = 0;
-
-  // Print out all of the generated code.
-  TheModule->dump();
-
   return 0;
 }
index b39a5a215af409dd5d4980963a8e785a3e6e4ef9..7338c6ebc50bd15b86b686827031e96a64bb23b5 100644 (file)
@@ -1,18 +1,12 @@
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/Passes.h"
-#include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/MCJIT.h"
-#include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/IR/DIBuilder.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-#include "llvm/Support/Host.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Transforms/Scalar.h"
 #include <cctype>
 #include <map>
 #include <string>
 #include <vector>
+#include "../include/KaleidoscopeJIT.h"
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -95,14 +92,11 @@ struct DebugInfo {
   DICompileUnit *TheCU;
   DIType *DblTy;
   std::vector<DIScope *> LexicalBlocks;
-  std::map<const PrototypeAST *, DIScope *> FnScopeMap;
 
   void emitLocation(ExprAST *AST);
   DIType *getDoubleTy();
 } KSDbgInfo;
 
-static std::string IdentifierStr; // Filled in if tok_identifier
-static double NumVal;             // Filled in if tok_number
 struct SourceLocation {
   int Line;
   int Col;
@@ -121,6 +115,9 @@ static int advance() {
   return LastChar;
 }
 
+static std::string IdentifierStr; // Filled in if tok_identifier
+static double NumVal;             // Filled in if tok_number
+
 /// gettok - Return the next token from standard input.
 static int gettok() {
   static int LastChar = ' ';
@@ -206,7 +203,7 @@ class ExprAST {
 public:
   ExprAST(SourceLocation Loc = CurLoc) : Loc(Loc) {}
   virtual ~ExprAST() {}
-  virtual Value *Codegen() = 0;
+  virtual Value *codegen() = 0;
   int getLine() const { return Loc.Line; }
   int getCol() const { return Loc.Col; }
   virtual raw_ostream &dump(raw_ostream &out, int ind) {
@@ -223,7 +220,7 @@ public:
   raw_ostream &dump(raw_ostream &out, int ind) override {
     return ExprAST::dump(out << Val, ind);
   }
-  Value *Codegen() override;
+  Value *codegen() override;
 };
 
 /// VariableExprAST - Expression class for referencing a variable, like "a".
@@ -234,10 +231,10 @@ public:
   VariableExprAST(SourceLocation Loc, const std::string &Name)
       : ExprAST(Loc), Name(Name) {}
   const std::string &getName() const { return Name; }
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     return ExprAST::dump(out << Name, ind);
   }
-  Value *Codegen() override;
 };
 
 /// UnaryExprAST - Expression class for a unary operator.
@@ -248,12 +245,12 @@ class UnaryExprAST : public ExprAST {
 public:
   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
       : Opcode(Opcode), Operand(std::move(Operand)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "unary" << Opcode, ind);
     Operand->dump(out, ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// BinaryExprAST - Expression class for a binary operator.
@@ -265,13 +262,13 @@ public:
   BinaryExprAST(SourceLocation Loc, char Op, std::unique_ptr<ExprAST> LHS,
                 std::unique_ptr<ExprAST> RHS)
       : ExprAST(Loc), Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "binary" << Op, ind);
     LHS->dump(indent(out, ind) << "LHS:", ind + 1);
     RHS->dump(indent(out, ind) << "RHS:", ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// CallExprAST - Expression class for function calls.
@@ -283,13 +280,13 @@ public:
   CallExprAST(SourceLocation Loc, const std::string &Callee,
               std::vector<std::unique_ptr<ExprAST>> Args)
       : ExprAST(Loc), Callee(Callee), Args(std::move(Args)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "call " << Callee, ind);
     for (const auto &Arg : Args)
       Arg->dump(indent(out, ind + 1), ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// IfExprAST - Expression class for if/then/else.
@@ -301,6 +298,7 @@ public:
             std::unique_ptr<ExprAST> Then, std::unique_ptr<ExprAST> Else)
       : ExprAST(Loc), Cond(std::move(Cond)), Then(std::move(Then)),
         Else(std::move(Else)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "if", ind);
     Cond->dump(indent(out, ind) << "Cond:", ind + 1);
@@ -308,7 +306,6 @@ public:
     Else->dump(indent(out, ind) << "Else:", ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// ForExprAST - Expression class for for/in.
@@ -322,6 +319,7 @@ public:
              std::unique_ptr<ExprAST> Body)
       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
         Step(std::move(Step)), Body(std::move(Body)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "for", ind);
     Start->dump(indent(out, ind) << "Cond:", ind + 1);
@@ -330,7 +328,6 @@ public:
     Body->dump(indent(out, ind) << "Body:", ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// VarExprAST - Expression class for var/in
@@ -343,6 +340,7 @@ public:
       std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
       std::unique_ptr<ExprAST> Body)
       : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
+  Value *codegen() override;
   raw_ostream &dump(raw_ostream &out, int ind) override {
     ExprAST::dump(out << "var", ind);
     for (const auto &NamedVar : VarNames)
@@ -350,7 +348,6 @@ public:
     Body->dump(indent(out, ind) << "Body:", ind + 1);
     return out;
   }
-  Value *Codegen() override;
 };
 
 /// PrototypeAST - This class represents the "prototype" for a function,
@@ -369,6 +366,8 @@ public:
                unsigned Prec = 0)
       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
         Precedence(Prec), Line(Loc.Line) {}
+  Function *codegen();
+  const std::string &getName() const { return Name; }
 
   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
@@ -379,11 +378,7 @@ public:
   }
 
   unsigned getBinaryPrecedence() const { return Precedence; }
-
-  Function *Codegen();
-
-  void CreateArgumentAllocas(Function *F);
-  const std::vector<std::string> &getArgs() const { return Args; }
+  int getLine() const { return Line; }
 };
 
 /// FunctionAST - This class represents a function definition itself.
@@ -395,15 +390,13 @@ public:
   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
               std::unique_ptr<ExprAST> Body)
       : Proto(std::move(Proto)), Body(std::move(Body)) {}
-
+  Function *codegen();
   raw_ostream &dump(raw_ostream &out, int ind) {
     indent(out, ind) << "FunctionAST\n";
     ++ind;
     indent(out, ind) << "Body:";
     return Body ? Body->dump(out, ind) : out << "null\n";
   }
-
-  Function *Codegen();
 };
 } // end anonymous namespace
 
@@ -442,10 +435,6 @@ std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
   Error(Str);
   return nullptr;
 }
-std::unique_ptr<FunctionAST> ErrorF(const char *Str) {
-  Error(Str);
-  return nullptr;
-}
 
 static std::unique_ptr<ExprAST> ParseExpression();
 
@@ -809,7 +798,7 @@ static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
   SourceLocation FnLoc = CurLoc;
   if (auto E = ParseExpression()) {
     // Make an anonymous proto.
-    auto Proto = llvm::make_unique<PrototypeAST>(FnLoc, "main",
+    auto Proto = llvm::make_unique<PrototypeAST>(FnLoc, "__anon_expr",
                                                  std::vector<std::string>());
     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
   }
@@ -826,7 +815,7 @@ static std::unique_ptr<PrototypeAST> ParseExtern() {
 // Debug Info Support
 //===----------------------------------------------------------------------===//
 
-static DIBuilder *DBuilder;
+static std::unique_ptr<DIBuilder> DBuilder;
 
 DIType *DebugInfo::getDoubleTy() {
   if (DblTy)
@@ -866,15 +855,31 @@ static DISubroutineType *CreateFunctionType(unsigned NumArgs, DIFile *Unit) {
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static Module *TheModule;
+static std::unique_ptr<Module> TheModule;
 static std::map<std::string, AllocaInst *> NamedValues;
-static legacy::FunctionPassManager *TheFPM;
+static std::unique_ptr<KaleidoscopeJIT> TheJIT;
+static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
 
 Value *ErrorV(const char *Str) {
   Error(Str);
   return nullptr;
 }
 
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = FunctionProtos.find(Name);
+  if (FI != FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
 /// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
 /// the function.  This is used for mutable variables etc.
 static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
@@ -885,12 +890,12 @@ static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
                            VarName.c_str());
 }
 
-Value *NumberExprAST::Codegen() {
+Value *NumberExprAST::codegen() {
   KSDbgInfo.emitLocation(this);
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
 }
 
-Value *VariableExprAST::Codegen() {
+Value *VariableExprAST::codegen() {
   // Look this variable up in the function.
   Value *V = NamedValues[Name];
   if (!V)
@@ -901,12 +906,12 @@ Value *VariableExprAST::Codegen() {
   return Builder.CreateLoad(V, Name.c_str());
 }
 
-Value *UnaryExprAST::Codegen() {
-  Value *OperandV = Operand->Codegen();
+Value *UnaryExprAST::codegen() {
+  Value *OperandV = Operand->codegen();
   if (!OperandV)
     return nullptr;
 
-  Function *F = TheModule->getFunction(std::string("unary") + Opcode);
+  Function *F = getFunction(std::string("unary") + Opcode);
   if (!F)
     return ErrorV("Unknown unary operator");
 
@@ -914,7 +919,7 @@ Value *UnaryExprAST::Codegen() {
   return Builder.CreateCall(F, OperandV, "unop");
 }
 
-Value *BinaryExprAST::Codegen() {
+Value *BinaryExprAST::codegen() {
   KSDbgInfo.emitLocation(this);
 
   // Special case '=' because we don't want to emit the LHS as an expression.
@@ -927,7 +932,7 @@ Value *BinaryExprAST::Codegen() {
     if (!LHSE)
       return ErrorV("destination of '=' must be a variable");
     // Codegen the RHS.
-    Value *Val = RHS->Codegen();
+    Value *Val = RHS->codegen();
     if (!Val)
       return nullptr;
 
@@ -940,8 +945,8 @@ Value *BinaryExprAST::Codegen() {
     return Val;
   }
 
-  Value *L = LHS->Codegen();
-  Value *R = RHS->Codegen();
+  Value *L = LHS->codegen();
+  Value *R = RHS->codegen();
   if (!L || !R)
     return nullptr;
 
@@ -963,18 +968,18 @@ Value *BinaryExprAST::Codegen() {
 
   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
   // a call to it.
-  Function *F = TheModule->getFunction(std::string("binary") + Op);
+  Function *F = getFunction(std::string("binary") + Op);
   assert(F && "binary operator not found!");
 
   Value *Ops[] = {L, R};
   return Builder.CreateCall(F, Ops, "binop");
 }
 
-Value *CallExprAST::Codegen() {
+Value *CallExprAST::codegen() {
   KSDbgInfo.emitLocation(this);
 
   // Look up the name in the global module table.
-  Function *CalleeF = TheModule->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
     return ErrorV("Unknown function referenced");
 
@@ -984,7 +989,7 @@ Value *CallExprAST::Codegen() {
 
   std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
-    ArgsV.push_back(Args[i]->Codegen());
+    ArgsV.push_back(Args[i]->codegen());
     if (!ArgsV.back())
       return nullptr;
   }
@@ -992,10 +997,10 @@ Value *CallExprAST::Codegen() {
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
-Value *IfExprAST::Codegen() {
+Value *IfExprAST::codegen() {
   KSDbgInfo.emitLocation(this);
 
-  Value *CondV = Cond->Codegen();
+  Value *CondV = Cond->codegen();
   if (!CondV)
     return nullptr;
 
@@ -1017,7 +1022,7 @@ Value *IfExprAST::Codegen() {
   // Emit then value.
   Builder.SetInsertPoint(ThenBB);
 
-  Value *ThenV = Then->Codegen();
+  Value *ThenV = Then->codegen();
   if (!ThenV)
     return nullptr;
 
@@ -1029,7 +1034,7 @@ Value *IfExprAST::Codegen() {
   TheFunction->getBasicBlockList().push_back(ElseBB);
   Builder.SetInsertPoint(ElseBB);
 
-  Value *ElseV = Else->Codegen();
+  Value *ElseV = Else->codegen();
   if (!ElseV)
     return nullptr;
 
@@ -1067,7 +1072,7 @@ Value *IfExprAST::Codegen() {
 //   store nextvar -> var
 //   br endcond, loop, endloop
 // outloop:
-Value *ForExprAST::Codegen() {
+Value *ForExprAST::codegen() {
   Function *TheFunction = Builder.GetInsertBlock()->getParent();
 
   // Create an alloca for the variable in the entry block.
@@ -1076,7 +1081,7 @@ Value *ForExprAST::Codegen() {
   KSDbgInfo.emitLocation(this);
 
   // Emit the start code first, without 'variable' in scope.
-  Value *StartVal = Start->Codegen();
+  Value *StartVal = Start->codegen();
   if (!StartVal)
     return nullptr;
 
@@ -1102,13 +1107,13 @@ Value *ForExprAST::Codegen() {
   // Emit the body of the loop.  This, like any other expr, can change the
   // current BB.  Note that we ignore the value computed by the body, but don't
   // allow an error.
-  if (!Body->Codegen())
+  if (!Body->codegen())
     return nullptr;
 
   // Emit the step value.
   Value *StepVal = nullptr;
   if (Step) {
-    StepVal = Step->Codegen();
+    StepVal = Step->codegen();
     if (!StepVal)
       return nullptr;
   } else {
@@ -1117,7 +1122,7 @@ Value *ForExprAST::Codegen() {
   }
 
   // Compute the end condition.
-  Value *EndCond = End->Codegen();
+  Value *EndCond = End->codegen();
   if (!EndCond)
     return nullptr;
 
@@ -1151,7 +1156,7 @@ Value *ForExprAST::Codegen() {
   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
 }
 
-Value *VarExprAST::Codegen() {
+Value *VarExprAST::codegen() {
   std::vector<AllocaInst *> OldBindings;
 
   Function *TheFunction = Builder.GetInsertBlock()->getParent();
@@ -1168,7 +1173,7 @@ Value *VarExprAST::Codegen() {
     //    var a = a in ...   # refers to outer 'a'.
     Value *InitVal;
     if (Init) {
-      InitVal = Init->Codegen();
+      InitVal = Init->codegen();
       if (!InitVal)
         return nullptr;
     } else { // If not specified, use 0.0.
@@ -1189,7 +1194,7 @@ Value *VarExprAST::Codegen() {
   KSDbgInfo.emitLocation(this);
 
   // Codegen the body, now that all vars are in scope.
-  Value *BodyVal = Body->Codegen();
+  Value *BodyVal = Body->codegen();
   if (!BodyVal)
     return nullptr;
 
@@ -1201,7 +1206,7 @@ Value *VarExprAST::Codegen() {
   return BodyVal;
 }
 
-Function *PrototypeAST::Codegen() {
+Function *PrototypeAST::codegen() {
   // Make the function type:  double(double,double) etc.
   std::vector<Type *> Doubles(Args.size(),
                               Type::getDoubleTy(getGlobalContext()));
@@ -1209,105 +1214,79 @@ Function *PrototypeAST::Codegen() {
       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
 
   Function *F =
-      Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-
-  // If F conflicted, there was already something named 'Name'.  If it has a
-  // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
-    // Delete the one we just made and get the existing one.
-    F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-
-    // If F already has a body, reject this.
-    if (!F->empty()) {
-      ErrorF("redefinition of function");
-      return nullptr;
-    }
-
-    // If F took a different number of args, reject.
-    if (F->arg_size() != Args.size()) {
-      ErrorF("redefinition of function with different # args");
-      return nullptr;
-    }
-  }
+      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
 
   // Set names for all arguments.
   unsigned Idx = 0;
-  for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
-       ++AI, ++Idx)
-    AI->setName(Args[Idx]);
+  for (auto &Arg : F->args())
+    Arg.setName(Args[Idx++]);
+
+  return F;
+}
+
+Function *FunctionAST::codegen() {
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
+  if (!TheFunction)
+    return nullptr;
+
+  // If this is an operator, install it.
+  if (P.isBinaryOp())
+    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
+
+  // Create a new basic block to start insertion into.
+  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
+  Builder.SetInsertPoint(BB);
 
   // Create a subprogram DIE for this function.
   DIFile *Unit = DBuilder->createFile(KSDbgInfo.TheCU->getFilename(),
                                       KSDbgInfo.TheCU->getDirectory());
   DIScope *FContext = Unit;
-  unsigned LineNo = Line;
-  unsigned ScopeLine = Line;
+  unsigned LineNo = P.getLine();
+  unsigned ScopeLine = LineNo;
   DISubprogram *SP = DBuilder->createFunction(
-      FContext, Name, StringRef(), Unit, LineNo,
-      CreateFunctionType(Args.size(), Unit), false /* internal linkage */,
-      true /* definition */, ScopeLine, DINode::FlagPrototyped, false, F);
+      FContext, P.getName(), StringRef(), Unit, LineNo,
+      CreateFunctionType(TheFunction->arg_size(), Unit),
+      false /* internal linkage */, true /* definition */, ScopeLine,
+      DINode::FlagPrototyped, false, TheFunction);
 
-  KSDbgInfo.FnScopeMap[this] = SP;
-  return F;
-}
+  // Push the current scope.
+  KSDbgInfo.LexicalBlocks.push_back(SP);
+
+  // Unset the location for the prologue emission (leading instructions with no
+  // location in a function are considered part of the prologue and the debugger
+  // will run past them when breaking on a function)
+  KSDbgInfo.emitLocation(nullptr);
 
-/// CreateArgumentAllocas - Create an alloca for each argument and register the
-/// argument in the symbol table so that references to it will succeed.
-void PrototypeAST::CreateArgumentAllocas(Function *F) {
-  Function::arg_iterator AI = F->arg_begin();
-  for (unsigned Idx = 0, e = Args.size(); Idx != e; ++Idx, ++AI) {
+  // Record the function arguments in the NamedValues map.
+  NamedValues.clear();
+  unsigned ArgIdx = 0;
+  for (auto &Arg : TheFunction->args()) {
     // Create an alloca for this variable.
-    AllocaInst *Alloca = CreateEntryBlockAlloca(F, Args[Idx]);
+    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
 
     // Create a debug descriptor for the variable.
-    DIScope *Scope = KSDbgInfo.LexicalBlocks.back();
-    DIFile *Unit = DBuilder->createFile(KSDbgInfo.TheCU->getFilename(),
-                                        KSDbgInfo.TheCU->getDirectory());
     DILocalVariable *D = DBuilder->createParameterVariable(
-        Scope, Args[Idx], Idx + 1, Unit, Line, KSDbgInfo.getDoubleTy(), true);
+        SP, Arg.getName(), ++ArgIdx, Unit, LineNo, KSDbgInfo.getDoubleTy(),
+        true);
 
     DBuilder->insertDeclare(Alloca, D, DBuilder->createExpression(),
-                            DebugLoc::get(Line, 0, Scope),
+                            DebugLoc::get(LineNo, 0, SP),
                             Builder.GetInsertBlock());
 
     // Store the initial value into the alloca.
-    Builder.CreateStore(AI, Alloca);
+    Builder.CreateStore(&Arg, Alloca);
 
     // Add arguments to variable symbol table.
-    NamedValues[Args[Idx]] = Alloca;
+    NamedValues[Arg.getName()] = Alloca;
   }
-}
-
-Function *FunctionAST::Codegen() {
-  NamedValues.clear();
-
-  Function *TheFunction = Proto->Codegen();
-  if (!TheFunction)
-    return nullptr;
-
-  // Push the current scope.
-  KSDbgInfo.LexicalBlocks.push_back(KSDbgInfo.FnScopeMap[Proto.get()]);
-
-  // Unset the location for the prologue emission (leading instructions with no
-  // location in a function are considered part of the prologue and the debugger
-  // will run past them when breaking on a function)
-  KSDbgInfo.emitLocation(nullptr);
-
-  // If this is an operator, install it.
-  if (Proto->isBinaryOp())
-    BinopPrecedence[Proto->getOperatorName()] = Proto->getBinaryPrecedence();
-
-  // Create a new basic block to start insertion into.
-  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
-  Builder.SetInsertPoint(BB);
-
-  // Add all arguments to the symbol table and create their allocas.
-  Proto->CreateArgumentAllocas(TheFunction);
 
   KSDbgInfo.emitLocation(Body.get());
 
-  if (Value *RetVal = Body->Codegen()) {
+  if (Value *RetVal = Body->codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
 
@@ -1317,16 +1296,13 @@ Function *FunctionAST::Codegen() {
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
-    // Optimize the function.
-    TheFPM->run(*TheFunction);
-
     return TheFunction;
   }
 
   // Error reading body, remove function.
   TheFunction->eraseFromParent();
 
-  if (Proto->isBinaryOp())
+  if (P.isBinaryOp())
     BinopPrecedence.erase(Proto->getOperatorName());
 
   // Pop off the lexical block for the function since we added it
@@ -1340,13 +1316,16 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
-static ExecutionEngine *TheExecutionEngine;
+static void InitializeModule() {
+  // Open a new module.
+  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
+}
 
 static void HandleDefinition() {
   if (auto FnAST = ParseDefinition()) {
-    if (!FnAST->Codegen()) {
+    if (!FnAST->codegen())
       fprintf(stderr, "Error reading function definition:");
-    }
   } else {
     // Skip token for error recovery.
     getNextToken();
@@ -1355,9 +1334,10 @@ static void HandleDefinition() {
 
 static void HandleExtern() {
   if (auto ProtoAST = ParseExtern()) {
-    if (!ProtoAST->Codegen()) {
+    if (!ProtoAST->codegen())
       fprintf(stderr, "Error reading extern");
-    }
+    else
+      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
   } else {
     // Skip token for error recovery.
     getNextToken();
@@ -1367,7 +1347,7 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (!FnAST->Codegen()) {
+    if (!FnAST->codegen()) {
       fprintf(stderr, "Error generating code for top level expr");
     }
   } else {
@@ -1422,7 +1402,6 @@ int main() {
   InitializeNativeTarget();
   InitializeNativeTargetAsmPrinter();
   InitializeNativeTargetAsmParser();
-  LLVMContext &Context = getGlobalContext();
 
   // Install standard binary operators.
   // 1 is lowest precedence.
@@ -1435,9 +1414,9 @@ int main() {
   // Prime the first token.
   getNextToken();
 
-  // Make the module, which holds all the code.
-  std::unique_ptr<Module> Owner = make_unique<Module>("my cool jit", Context);
-  TheModule = Owner.get();
+  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
+
+  InitializeModule();
 
   // Add the current debug info version into the module.
   TheModule->addModuleFlag(Module::Warning, "Debug Info Version",
@@ -1448,7 +1427,7 @@ int main() {
     TheModule->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2);
 
   // Construct the DIBuilder, we do this here because we need the module.
-  DBuilder = new DIBuilder(*TheModule);
+  DBuilder = llvm::make_unique<DIBuilder>(*TheModule);
 
   // Create the compile unit for the module.
   // Currently down as "fib.ks" as a filename since we're redirecting stdin
@@ -1456,47 +1435,9 @@ int main() {
   KSDbgInfo.TheCU = DBuilder->createCompileUnit(
       dwarf::DW_LANG_C, "fib.ks", ".", "Kaleidoscope Compiler", 0, "", 0);
 
-  // Create the JIT.  This takes ownership of the module.
-  std::string ErrStr;
-  TheExecutionEngine =
-      EngineBuilder(std::move(Owner))
-          .setErrorStr(&ErrStr)
-          .setMCJITMemoryManager(llvm::make_unique<SectionMemoryManager>())
-          .create();
-  if (!TheExecutionEngine) {
-    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-    exit(1);
-  }
-
-  legacy::FunctionPassManager OurFPM(TheModule);
-
-  // Set up the optimizer pipeline.  Start with registering info about how the
-  // target lays out data structures.
-  TheModule->setDataLayout(TheExecutionEngine->getDataLayout());
-#if 0
-  // Provide basic AliasAnalysis support for GVN.
-  OurFPM.add(createBasicAliasAnalysisPass());
-  // Promote allocas to registers.
-  OurFPM.add(createPromoteMemoryToRegisterPass());
-  // Do simple "peephole" optimizations and bit-twiddling optzns.
-  OurFPM.add(createInstructionCombiningPass());
-  // Reassociate expressions.
-  OurFPM.add(createReassociatePass());
-  // Eliminate Common SubExpressions.
-  OurFPM.add(createGVNPass());
-  // Simplify the control flow graph (deleting unreachable blocks, etc).
-  OurFPM.add(createCFGSimplificationPass());
-#endif
-  OurFPM.doInitialization();
-
-  // Set the global so the code gen can use this.
-  TheFPM = &OurFPM;
-
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  TheFPM = 0;
-
   // Finalize the debug info.
   DBuilder->finalize();
 
diff --git a/examples/Kaleidoscope/include/KaleidoscopeJIT.h b/examples/Kaleidoscope/include/KaleidoscopeJIT.h
new file mode 100644 (file)
index 0000000..0c825cc
--- /dev/null
@@ -0,0 +1,114 @@
+//===----- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope ----*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains a simple JIT definition for use in the kaleidoscope tutorials.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
+#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
+
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
+#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
+#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
+#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/IR/Mangler.h"
+#include "llvm/Support/DynamicLibrary.h"
+
+namespace llvm {
+namespace orc {
+
+class KaleidoscopeJIT {
+public:
+  typedef ObjectLinkingLayer<> ObjLayerT;
+  typedef IRCompileLayer<ObjLayerT> CompileLayerT;
+  typedef CompileLayerT::ModuleSetHandleT ModuleHandleT;
+
+  KaleidoscopeJIT()
+      : TM(EngineBuilder().selectTarget()), DL(TM->createDataLayout()),
+        CompileLayer(ObjectLayer, SimpleCompiler(*TM)) {
+    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
+  }
+
+  TargetMachine &getTargetMachine() { return *TM; }
+
+  ModuleHandleT addModule(std::unique_ptr<Module> M) {
+    // We need a memory manager to allocate memory and resolve symbols for this
+    // new module. Create one that resolves symbols by looking back into the
+    // JIT.
+    auto Resolver = createLambdaResolver(
+        [&](const std::string &Name) {
+          if (auto Sym = findMangledSymbol(Name))
+            return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+          return RuntimeDyld::SymbolInfo(nullptr);
+        },
+        [](const std::string &S) { return nullptr; });
+    auto H = CompileLayer.addModuleSet(singletonSet(std::move(M)),
+                                       make_unique<SectionMemoryManager>(),
+                                       std::move(Resolver));
+
+    ModuleHandles.push_back(H);
+    return H;
+  }
+
+  void removeModule(ModuleHandleT H) {
+    ModuleHandles.erase(
+        std::find(ModuleHandles.begin(), ModuleHandles.end(), H));
+    CompileLayer.removeModuleSet(H);
+  }
+
+  JITSymbol findSymbol(const std::string Name) {
+    return findMangledSymbol(mangle(Name));
+  }
+
+private:
+
+  std::string mangle(const std::string &Name) {
+    std::string MangledName;
+    {
+      raw_string_ostream MangledNameStream(MangledName);
+      Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
+    }
+    return MangledName;
+  }
+
+  template <typename T> static std::vector<T> singletonSet(T t) {
+    std::vector<T> Vec;
+    Vec.push_back(std::move(t));
+    return Vec;
+  }
+
+  JITSymbol findMangledSymbol(const std::string &Name) {
+    // Search modules in reverse order: from last added to first added.
+    // This is the opposite of the usual search order for dlsym, but makes more
+    // sense in a REPL where we want to bind to the newest available definition.
+    for (auto H : make_range(ModuleHandles.rbegin(), ModuleHandles.rend()))
+      if (auto Sym = CompileLayer.findSymbolIn(H, Name, true))
+        return Sym;
+
+    // If we can't find the symbol in the JIT, try looking in the host process.
+    if (auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name))
+      return JITSymbol(SymAddr, JITSymbolFlags::Exported);
+
+    return nullptr;
+  }
+
+  std::unique_ptr<TargetMachine> TM;
+  const DataLayout DL;
+  ObjLayerT ObjectLayer;
+  CompileLayerT CompileLayer;
+  std::vector<ModuleHandleT> ModuleHandles;
+};
+
+} // End namespace orc.
+} // End namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H