Implementing JIT (Just In Time) Compilation

JIT (Just-In-Time) compilation transforms bytecode (or another intermediate representation) into machine-executable instructions at run time. In this article, we will add a JIT to our Kaleidoscope interpreter.

Table of contents.

  1. Introduction.
  2. The JIT compiler.
  3. Summary.

Prerequisites.

  1. LLVM Compiler optimizations

Introduction.

A JIT (Just-In-Time) compiler is a compiler that converts bytecode into instructions that can be executed by the target machine, and does so while the program is running. JIT compilers are mainly used where we want to improve or optimize the performance of the code at run time. For example, the Java JIT compiler improves the performance of Java programs at run time.

A system with a JIT continuously analyzes the code during execution and identifies blocks where the speedup gained from compilation or recompilation outweighs the overhead of compiling the code.
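The idea of compiling only when it pays off can be sketched with a simple call counter. The block below is a toy model, not part of the Kaleidoscope code: TieredFunction, its threshold and the two callables are hypothetical names, and the "compiled" version is just a second callable supplied by the caller. The JIT built in this article does not profile code this way; the sketch only illustrates the general technique.

```cpp
#include <cstdint>
#include <functional>
#include <utility>

// Hypothetical toy model: a function that starts out on a slow
// ("interpreted") path and is promoted to a fast ("compiled") path
// once it has been called often enough.
class TieredFunction {
public:
  TieredFunction(std::function<double(double)> SlowPath,
                 std::function<double(double)> FastPath,
                 std::uint32_t HotThreshold)
      : Slow(std::move(SlowPath)), Fast(std::move(FastPath)),
        Threshold(HotThreshold) {}

  double call(double X) {
    if (Compiled)
      return Fast(X);
    // Count calls on the slow path; promote once the function is "hot".
    if (++CallCount >= Threshold)
      Compiled = true; // a real JIT would invoke the compiler here
    return Slow(X);
  }

  bool isCompiled() const { return Compiled; }

private:
  std::function<double(double)> Slow;
  std::function<double(double)> Fast;
  std::uint32_t Threshold;
  std::uint32_t CallCount = 0;
  bool Compiled = false;
};
```

With a threshold of 3, the first three calls take the slow path and every later call takes the fast one. Real systems usually weigh more than a call count, for example loop iteration counts or time spent in a block.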

The JIT compiler.

The LLVM IR we obtained in the prerequisite articles is the common currency between many different parts of the compiler. For example, we can convert it to textual format or binary format, run optimizations on it, or JIT compile it.

The basic idea is that the user enters an expression or a function definition, and top-level expressions are evaluated immediately as they are typed in. For example, if the user types in 1 + 2, the JIT evaluates it and prints 3. If a function has been defined, it should be callable from the REPL.

For this, we prepare the environment to generate code for the current native target, then declare and initialize the JIT. We do this by calling the InitializeNativeTarget* functions, adding a global variable TheJIT, and initializing it in main, the entry point of the program.

static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

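  // Create the JIT. This assumes the ORC-based KaleidoscopeJIT, matching the
  // lookup and ResourceTracker calls used later. Then open the first module.
  TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
  InitializeModuleAndPassManager();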

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

KaleidoscopeJIT is a simple JIT built for this tutorial; its source is in the LLVM source tree at llvm/examples/Kaleidoscope/include/KaleidoscopeJIT.h. Its API is small: addModule adds an LLVM IR module to the JIT, making its functions available for execution; removing a module (done through a ResourceTracker in the code that follows) frees any memory associated with the code in that module; and lookup enables us to look up pointers to the compiled code.

We then set up the data layout for the JIT. For this, we add the following line to our InitializeModuleAndPassManager function:

TheModule->setDataLayout(TheJIT->getDataLayout()); // use the JIT's data layout (ORC-based KaleidoscopeJIT)

We use this API to change the code that parses top-level expressions. Our handleTopLevelExpression now looks like the following:

static void handleTopLevelExpression()
{
    if (auto FnAST = ParseTopLevelExpr()) // evaluate top-level expression into anonymous function
    {
        if (FnAST->codegen())
        {
            auto RT = TheJIT->getMainJITDylib().createResourceTracker(); // create resource tracker, tracks JIT allocated memory

            auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
            ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
            InitializeModuleAndPassManager();

            auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr")); // search JIT for __anon_ expression

            double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress(); // get symbol address
            fprintf(stderr, "Evaluated to %f\n", FP());

            ExitOnErr(RT->remove()); // remove anonymous expression module from JIT
        }
    }
    else
    {
        getNextToken(); // skip token, error recovery
    }
}

If parsing and code generation are successful, we add the module containing the top-level expression to the JIT. For this, we call addModule, which triggers code generation for all functions in the module. We also pass it a ResourceTracker, created just before the call, which we use later to remove the module from the JIT.

Once a module has been added to the JIT, it can no longer be modified, so we open a new module to hold subsequent code. For this we call InitializeModuleAndPassManager().
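The hand-off described above follows from how std::unique_ptr ownership works: moving the module into the JIT leaves the global pointer null, so a fresh module has to be created before any further code is generated. The following is a minimal sketch of that pattern using stand-in types. ToyModule, ToyJIT, initializeModule and submitCurrentModule are hypothetical names for illustration, not LLVM APIs.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for llvm::Module and the JIT; not LLVM APIs.
struct ToyModule {
  std::string Name;
  std::vector<std::string> Functions;
};

struct ToyJIT {
  std::vector<std::unique_ptr<ToyModule>> Owned;
  // Takes ownership: the caller can no longer modify the module.
  void addModule(std::unique_ptr<ToyModule> M) { Owned.push_back(std::move(M)); }
};

static std::unique_ptr<ToyModule> TheModule;

// Plays the role of InitializeModuleAndPassManager(): open a fresh module.
void initializeModule() {
  TheModule = std::make_unique<ToyModule>();
  TheModule->Name = "toy_module";
}

// Hand the current module to the JIT, then open a new one for later code.
void submitCurrentModule(ToyJIT &JIT) {
  JIT.addModule(std::move(TheModule)); // TheModule is now null
  initializeModule();                  // so it must be re-created
}
```

Skipping the second step would leave TheModule null, and the next attempt to generate code into it would dereference a null pointer.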

Once the module has been added to the JIT, we get a pointer to the final generated code by calling the JIT's lookup method and passing the name of the top-level expression function, __anon_expr. Since we just added this function, the lookup is expected to succeed; ExitOnErr terminates the program if it does not.

We then get the in-memory address of __anon_expr by calling getAddress on the symbol. Because we compile top-level expressions into a self-contained LLVM function that takes no arguments and returns a double, we can cast this address to a function pointer of that type and call it directly.
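The cast from a raw address to a callable function pointer can be tried out without LLVM. In the sketch below, anonExpr is a hand-written stand-in for the JIT-compiled expression, and lookupAddress and callThroughAddress are hypothetical helpers that mimic receiving an address as an integer and calling through it. Converting between function pointers and integers is implementation-defined in C++, although it works on mainstream platforms.

```cpp
#include <cstdint>

// Stand-in for the JIT-compiled top-level expression "1 + 2".
// In the article this code is produced by the JIT, not written by hand.
double anonExpr() { return 1.0 + 2.0; }

// Stand-in for a symbol lookup: returns the function's address as an
// integer, the way a JIT hands back a raw address.
std::uint64_t lookupAddress() {
  return static_cast<std::uint64_t>(
      reinterpret_cast<std::uintptr_t>(&anonExpr));
}

// Convert the raw address back into a function pointer taking no arguments
// and returning double, then call it.
double callThroughAddress(std::uint64_t Address) {
  double (*FP)() =
      reinterpret_cast<double (*)()>(static_cast<std::uintptr_t>(Address));
  return FP();
}
```

The function pointer type must match the signature of the compiled code exactly. Calling through a pointer of the wrong type is undefined behavior.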

Since we don't support re-evaluation of top-level expressions, we remove the module from the JIT once we are done with it, which frees the associated memory.

We can also allow each function to live in its own module. For this, we need to regenerate previous function declarations in each new module we open:

static std::unique_ptr<KaleidoscopeJIT> TheJIT;

...

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

...

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);

...

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

First we add a new global, FunctionProtos. It holds the most recent prototype for each function. We also add a getFunction() helper, which replaces calls to TheModule->getFunction(). It searches the current module for an existing function declaration and, if it does not find one, falls back to generating a new declaration from FunctionProtos.

In CallExprAST::codegen() we replace the call to TheModule->getFunction() with getFunction(). In FunctionAST::codegen() we first update the FunctionProtos map and then call getFunction(). With that, we are done.
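The lookup-with-fallback logic can be modeled without LLVM. The sketch below is a toy under stated assumptions: ToyProto stands in for the prototype AST node, a set of names stands in for the declarations in the current module, and openNewModule is a hypothetical helper that models handing the module to the JIT. It reuses the names getFunction and FunctionProtos only to mirror the article.

```cpp
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Hypothetical stand-in for the prototype AST node: a name and an arity.
struct ToyProto {
  std::string Name;
  unsigned NumArgs;
};

// Declarations present in the module currently being built.
static std::set<std::string> CurrentModuleDecls;
// Most recent prototype seen for each function name, across all modules.
static std::map<std::string, std::unique_ptr<ToyProto>> FunctionProtos;

// Returns true if Name is, or can be made, declared in the current module.
bool getFunction(const std::string &Name) {
  // First, see if the function is already declared in the current module.
  if (CurrentModuleDecls.count(Name))
    return true;
  // Otherwise re-emit the declaration from a remembered prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end()) {
    CurrentModuleDecls.insert(FI->second->Name);
    return true;
  }
  return false; // unknown function
}

// Models handing the module to the JIT and opening an empty one.
void openNewModule() { CurrentModuleDecls.clear(); }
```

After openNewModule(), a function defined earlier is no longer declared in the current module, but getFunction() still finds it because its prototype is kept in FunctionProtos. A name that was never seen is still reported as unknown.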

Next, we update the helper functions HandleDefinition and HandleExtern.

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
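      // Transfer the new function's module to the JIT. The ORC-style addModule
      // takes a ThreadSafeModule and returns an Error.
      ExitOnErr(TheJIT->addModule(
          ThreadSafeModule(std::move(TheModule), std::move(TheContext))));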
      InitializeModuleAndPassManager();
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

In the first helper function, HandleDefinition, we add two lines which transfer the newly defined function to the JIT and open a new module.
In HandleExtern, we add a single line that adds the prototype to FunctionProtos.

Summary.

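JIT compilation means compiling code during execution, at run time, rather than ahead of time. In this article we added a JIT to the Kaleidoscope REPL: top-level expressions are compiled to native code and evaluated immediately, and each function definition lives in its own module so that later expressions can call it.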