1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/LegacyPassManager.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/IR/Type.h"
12 #include "llvm/IR/Verifier.h"
13 #include "llvm/Support/TargetSelect.h"
14 #include "llvm/Target/TargetMachine.h"
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Transforms/Scalar/GVN.h"
17 #include "../include/KaleidoscopeJIT.h"
18 #include <algorithm>
19 #include <cassert>
20 #include <cctype>
21 #include <cstdint>
22 #include <cstdio>
23 #include <cstdlib>
24 #include <map>
25 #include <memory>
26 #include <string>
27 #include <vector>
28 
29 using namespace llvm;
30 using namespace llvm::orc;
31 
32 //===----------------------------------------------------------------------===//
33 // Lexer
34 //===----------------------------------------------------------------------===//
35 
// Token codes produced by the lexer. Characters the lexer does not recognise
// are returned as their own value in [0-255]; everything it does recognise is
// reported with one of the negative codes below.
enum Token {
  tok_eof = -1,        // end of input

  tok_def = -2,        // command: the "def" keyword
  tok_extern = -3,     // command: the "extern" keyword

  tok_identifier = -4, // primary: an identifier
  tok_number = -5      // primary: a numeric literal
};
49 
// Payload of the most recently lexed token: gettok() writes the identifier or
// keyword text to IdentifierStr and the parsed numeric value to NumVal.
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number
52 
53 /// gettok - Return the next token from standard input.
54 static int gettok() {
55   static int LastChar = ' ';
56 
57   // Skip any whitespace.
58   while (isspace(LastChar))
59     LastChar = getchar();
60 
61   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
62     IdentifierStr = LastChar;
63     while (isalnum((LastChar = getchar())))
64       IdentifierStr += LastChar;
65 
66     if (IdentifierStr == "def")
67       return tok_def;
68     if (IdentifierStr == "extern")
69       return tok_extern;
70     return tok_identifier;
71   }
72 
73   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
74     std::string NumStr;
75     do {
76       NumStr += LastChar;
77       LastChar = getchar();
78     } while (isdigit(LastChar) || LastChar == '.');
79 
80     NumVal = strtod(NumStr.c_str(), nullptr);
81     return tok_number;
82   }
83 
84   if (LastChar == '#') {
85     // Comment until end of line.
86     do
87       LastChar = getchar();
88     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
89 
90     if (LastChar != EOF)
91       return gettok();
92   }
93 
94   // Check for end of file.  Don't eat the EOF.
95   if (LastChar == EOF)
96     return tok_eof;
97 
98   // Otherwise, just return the character as its ascii value.
99   int ThisChar = LastChar;
100   LastChar = getchar();
101   return ThisChar;
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Abstract Syntax Tree (aka Parse Tree)
106 //===----------------------------------------------------------------------===//
107 
108 namespace {
109 
110 /// ExprAST - Base class for all expression nodes.
111 class ExprAST {
112 public:
113   virtual ~ExprAST() = default;
114 
115   virtual Value *codegen() = 0;
116 };
117 
118 /// NumberExprAST - Expression class for numeric literals like "1.0".
119 class NumberExprAST : public ExprAST {
120   double Val;
121 
122 public:
123   NumberExprAST(double Val) : Val(Val) {}
124 
125   Value *codegen() override;
126 };
127 
128 /// VariableExprAST - Expression class for referencing a variable, like "a".
129 class VariableExprAST : public ExprAST {
130   std::string Name;
131 
132 public:
133   VariableExprAST(const std::string &Name) : Name(Name) {}
134 
135   Value *codegen() override;
136 };
137 
138 /// BinaryExprAST - Expression class for a binary operator.
139 class BinaryExprAST : public ExprAST {
140   char Op;
141   std::unique_ptr<ExprAST> LHS, RHS;
142 
143 public:
144   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
145                 std::unique_ptr<ExprAST> RHS)
146       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
147 
148   Value *codegen() override;
149 };
150 
151 /// CallExprAST - Expression class for function calls.
152 class CallExprAST : public ExprAST {
153   std::string Callee;
154   std::vector<std::unique_ptr<ExprAST>> Args;
155 
156 public:
157   CallExprAST(const std::string &Callee,
158               std::vector<std::unique_ptr<ExprAST>> Args)
159       : Callee(Callee), Args(std::move(Args)) {}
160 
161   Value *codegen() override;
162 };
163 
164 /// PrototypeAST - This class represents the "prototype" for a function,
165 /// which captures its name, and its argument names (thus implicitly the number
166 /// of arguments the function takes).
167 class PrototypeAST {
168   std::string Name;
169   std::vector<std::string> Args;
170 
171 public:
172   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
173       : Name(Name), Args(std::move(Args)) {}
174 
175   Function *codegen();
176   const std::string &getName() const { return Name; }
177 };
178 
179 /// FunctionAST - This class represents a function definition itself.
180 class FunctionAST {
181   std::unique_ptr<PrototypeAST> Proto;
182   std::unique_ptr<ExprAST> Body;
183 
184 public:
185   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
186               std::unique_ptr<ExprAST> Body)
187       : Proto(std::move(Proto)), Body(std::move(Body)) {}
188 
189   Function *codegen();
190 };
191 
192 } // end anonymous namespace
193 
194 //===----------------------------------------------------------------------===//
195 // Parser
196 //===----------------------------------------------------------------------===//
197 
198 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
199 /// token the parser is looking at.  getNextToken reads another token from the
200 /// lexer and updates CurTok with its results.
201 static int CurTok;
202 static int getNextToken() { return CurTok = gettok(); }
203 
204 /// BinopPrecedence - This holds the precedence for each binary operator that is
205 /// defined.
206 static std::map<char, int> BinopPrecedence;
207 
208 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
209 static int GetTokPrecedence() {
210   if (!isascii(CurTok))
211     return -1;
212 
213   // Make sure it's a declared binop.
214   int TokPrec = BinopPrecedence[CurTok];
215   if (TokPrec <= 0)
216     return -1;
217   return TokPrec;
218 }
219 
220 /// LogError* - These are little helper functions for error handling.
221 std::unique_ptr<ExprAST> LogError(const char *Str) {
222   fprintf(stderr, "Error: %s\n", Str);
223   return nullptr;
224 }
225 
226 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
227   LogError(Str);
228   return nullptr;
229 }
230 
// Forward declaration: ParseParenExpr and ParseIdentifierExpr call
// ParseExpression before its definition appears.
static std::unique_ptr<ExprAST> ParseExpression();
232 
233 /// numberexpr ::= number
234 static std::unique_ptr<ExprAST> ParseNumberExpr() {
235   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
236   getNextToken(); // consume the number
237   return std::move(Result);
238 }
239 
240 /// parenexpr ::= '(' expression ')'
241 static std::unique_ptr<ExprAST> ParseParenExpr() {
242   getNextToken(); // eat (.
243   auto V = ParseExpression();
244   if (!V)
245     return nullptr;
246 
247   if (CurTok != ')')
248     return LogError("expected ')'");
249   getNextToken(); // eat ).
250   return V;
251 }
252 
253 /// identifierexpr
254 ///   ::= identifier
255 ///   ::= identifier '(' expression* ')'
256 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
257   std::string IdName = IdentifierStr;
258 
259   getNextToken(); // eat identifier.
260 
261   if (CurTok != '(') // Simple variable ref.
262     return llvm::make_unique<VariableExprAST>(IdName);
263 
264   // Call.
265   getNextToken(); // eat (
266   std::vector<std::unique_ptr<ExprAST>> Args;
267   if (CurTok != ')') {
268     while (true) {
269       if (auto Arg = ParseExpression())
270         Args.push_back(std::move(Arg));
271       else
272         return nullptr;
273 
274       if (CurTok == ')')
275         break;
276 
277       if (CurTok != ',')
278         return LogError("Expected ')' or ',' in argument list");
279       getNextToken();
280     }
281   }
282 
283   // Eat the ')'.
284   getNextToken();
285 
286   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
287 }
288 
289 /// primary
290 ///   ::= identifierexpr
291 ///   ::= numberexpr
292 ///   ::= parenexpr
293 static std::unique_ptr<ExprAST> ParsePrimary() {
294   switch (CurTok) {
295   default:
296     return LogError("unknown token when expecting an expression");
297   case tok_identifier:
298     return ParseIdentifierExpr();
299   case tok_number:
300     return ParseNumberExpr();
301   case '(':
302     return ParseParenExpr();
303   }
304 }
305 
306 /// binoprhs
307 ///   ::= ('+' primary)*
308 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
309                                               std::unique_ptr<ExprAST> LHS) {
310   // If this is a binop, find its precedence.
311   while (true) {
312     int TokPrec = GetTokPrecedence();
313 
314     // If this is a binop that binds at least as tightly as the current binop,
315     // consume it, otherwise we are done.
316     if (TokPrec < ExprPrec)
317       return LHS;
318 
319     // Okay, we know this is a binop.
320     int BinOp = CurTok;
321     getNextToken(); // eat binop
322 
323     // Parse the primary expression after the binary operator.
324     auto RHS = ParsePrimary();
325     if (!RHS)
326       return nullptr;
327 
328     // If BinOp binds less tightly with RHS than the operator after RHS, let
329     // the pending operator take RHS as its LHS.
330     int NextPrec = GetTokPrecedence();
331     if (TokPrec < NextPrec) {
332       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
333       if (!RHS)
334         return nullptr;
335     }
336 
337     // Merge LHS/RHS.
338     LHS =
339         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
340   }
341 }
342 
343 /// expression
344 ///   ::= primary binoprhs
345 ///
346 static std::unique_ptr<ExprAST> ParseExpression() {
347   auto LHS = ParsePrimary();
348   if (!LHS)
349     return nullptr;
350 
351   return ParseBinOpRHS(0, std::move(LHS));
352 }
353 
354 /// prototype
355 ///   ::= id '(' id* ')'
356 static std::unique_ptr<PrototypeAST> ParsePrototype() {
357   if (CurTok != tok_identifier)
358     return LogErrorP("Expected function name in prototype");
359 
360   std::string FnName = IdentifierStr;
361   getNextToken();
362 
363   if (CurTok != '(')
364     return LogErrorP("Expected '(' in prototype");
365 
366   std::vector<std::string> ArgNames;
367   while (getNextToken() == tok_identifier)
368     ArgNames.push_back(IdentifierStr);
369   if (CurTok != ')')
370     return LogErrorP("Expected ')' in prototype");
371 
372   // success.
373   getNextToken(); // eat ')'.
374 
375   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
376 }
377 
378 /// definition ::= 'def' prototype expression
379 static std::unique_ptr<FunctionAST> ParseDefinition() {
380   getNextToken(); // eat def.
381   auto Proto = ParsePrototype();
382   if (!Proto)
383     return nullptr;
384 
385   if (auto E = ParseExpression())
386     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
387   return nullptr;
388 }
389 
390 /// toplevelexpr ::= expression
391 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
392   if (auto E = ParseExpression()) {
393     // Make an anonymous proto.
394     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
395                                                  std::vector<std::string>());
396     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
397   }
398   return nullptr;
399 }
400 
401 /// external ::= 'extern' prototype
402 static std::unique_ptr<PrototypeAST> ParseExtern() {
403   getNextToken(); // eat extern.
404   return ParsePrototype();
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // Code Generation
409 //===----------------------------------------------------------------------===//
410 
// Context passed to every type, constant, basic block and module created here.
static LLVMContext TheContext;
// Instruction builder used by the codegen() methods; FunctionAST::codegen
// points it at the entry block of the function being generated.
static IRBuilder<> Builder(TheContext);
// Module currently receiving generated functions. It is handed to the JIT and
// replaced with a fresh one by InitializeModuleAndPassManager().
static std::unique_ptr<Module> TheModule;
// Argument name -> value for the function being generated; cleared at the
// start of each FunctionAST::codegen().
static std::map<std::string, Value *> NamedValues;
// Function pass manager run over each function after it is generated.
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
// The JIT that modules are added to; created in main().
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Latest prototype recorded for each function name; getFunction() uses it to
// emit a declaration when the current module does not contain the function.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
418 
419 Value *LogErrorV(const char *Str) {
420   LogError(Str);
421   return nullptr;
422 }
423 
424 Function *getFunction(std::string Name) {
425   // First, see if the function has already been added to the current module.
426   if (auto *F = TheModule->getFunction(Name))
427     return F;
428 
429   // If not, check whether we can codegen the declaration from some existing
430   // prototype.
431   auto FI = FunctionProtos.find(Name);
432   if (FI != FunctionProtos.end())
433     return FI->second->codegen();
434 
435   // If no existing prototype exists, return null.
436   return nullptr;
437 }
438 
439 Value *NumberExprAST::codegen() {
440   return ConstantFP::get(TheContext, APFloat(Val));
441 }
442 
443 Value *VariableExprAST::codegen() {
444   // Look this variable up in the function.
445   Value *V = NamedValues[Name];
446   if (!V)
447     return LogErrorV("Unknown variable name");
448   return V;
449 }
450 
451 Value *BinaryExprAST::codegen() {
452   Value *L = LHS->codegen();
453   Value *R = RHS->codegen();
454   if (!L || !R)
455     return nullptr;
456 
457   switch (Op) {
458   case '+':
459     return Builder.CreateFAdd(L, R, "addtmp");
460   case '-':
461     return Builder.CreateFSub(L, R, "subtmp");
462   case '*':
463     return Builder.CreateFMul(L, R, "multmp");
464   case '<':
465     L = Builder.CreateFCmpULT(L, R, "cmptmp");
466     // Convert bool 0/1 to double 0.0 or 1.0
467     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
468   default:
469     return LogErrorV("invalid binary operator");
470   }
471 }
472 
473 Value *CallExprAST::codegen() {
474   // Look up the name in the global module table.
475   Function *CalleeF = getFunction(Callee);
476   if (!CalleeF)
477     return LogErrorV("Unknown function referenced");
478 
479   // If argument mismatch error.
480   if (CalleeF->arg_size() != Args.size())
481     return LogErrorV("Incorrect # arguments passed");
482 
483   std::vector<Value *> ArgsV;
484   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
485     ArgsV.push_back(Args[i]->codegen());
486     if (!ArgsV.back())
487       return nullptr;
488   }
489 
490   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
491 }
492 
493 Function *PrototypeAST::codegen() {
494   // Make the function type:  double(double,double) etc.
495   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
496   FunctionType *FT =
497       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
498 
499   Function *F =
500       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
501 
502   // Set names for all arguments.
503   unsigned Idx = 0;
504   for (auto &Arg : F->args())
505     Arg.setName(Args[Idx++]);
506 
507   return F;
508 }
509 
510 Function *FunctionAST::codegen() {
511   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
512   // reference to it for use below.
513   auto &P = *Proto;
514   FunctionProtos[Proto->getName()] = std::move(Proto);
515   Function *TheFunction = getFunction(P.getName());
516   if (!TheFunction)
517     return nullptr;
518 
519   // Create a new basic block to start insertion into.
520   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
521   Builder.SetInsertPoint(BB);
522 
523   // Record the function arguments in the NamedValues map.
524   NamedValues.clear();
525   for (auto &Arg : TheFunction->args())
526     NamedValues[Arg.getName()] = &Arg;
527 
528   if (Value *RetVal = Body->codegen()) {
529     // Finish off the function.
530     Builder.CreateRet(RetVal);
531 
532     // Validate the generated code, checking for consistency.
533     verifyFunction(*TheFunction);
534 
535     // Run the optimizer on the function.
536     TheFPM->run(*TheFunction);
537 
538     return TheFunction;
539   }
540 
541   // Error reading body, remove function.
542   TheFunction->eraseFromParent();
543   return nullptr;
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // Top-Level parsing and JIT Driver
548 //===----------------------------------------------------------------------===//
549 
550 static void InitializeModuleAndPassManager() {
551   // Open a new module.
552   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
553   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
554 
555   // Create a new pass manager attached to it.
556   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
557 
558   // Do simple "peephole" optimizations and bit-twiddling optzns.
559   TheFPM->add(createInstructionCombiningPass());
560   // Reassociate expressions.
561   TheFPM->add(createReassociatePass());
562   // Eliminate Common SubExpressions.
563   TheFPM->add(createGVNPass());
564   // Simplify the control flow graph (deleting unreachable blocks, etc).
565   TheFPM->add(createCFGSimplificationPass());
566 
567   TheFPM->doInitialization();
568 }
569 
570 static void HandleDefinition() {
571   if (auto FnAST = ParseDefinition()) {
572     if (auto *FnIR = FnAST->codegen()) {
573       fprintf(stderr, "Read function definition:");
574       FnIR->dump();
575       TheJIT->addModule(std::move(TheModule));
576       InitializeModuleAndPassManager();
577     }
578   } else {
579     // Skip token for error recovery.
580     getNextToken();
581   }
582 }
583 
584 static void HandleExtern() {
585   if (auto ProtoAST = ParseExtern()) {
586     if (auto *FnIR = ProtoAST->codegen()) {
587       fprintf(stderr, "Read extern: ");
588       FnIR->dump();
589       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
590     }
591   } else {
592     // Skip token for error recovery.
593     getNextToken();
594   }
595 }
596 
597 static void HandleTopLevelExpression() {
598   // Evaluate a top-level expression into an anonymous function.
599   if (auto FnAST = ParseTopLevelExpr()) {
600     if (FnAST->codegen()) {
601       // JIT the module containing the anonymous expression, keeping a handle so
602       // we can free it later.
603       auto H = TheJIT->addModule(std::move(TheModule));
604       InitializeModuleAndPassManager();
605 
606       // Search the JIT for the __anon_expr symbol.
607       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
608       assert(ExprSymbol && "Function not found");
609 
610       // Get the symbol's address and cast it to the right type (takes no
611       // arguments, returns a double) so we can call it as a native function.
612       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
613       fprintf(stderr, "Evaluated to %f\n", FP());
614 
615       // Delete the anonymous expression module from the JIT.
616       TheJIT->removeModule(H);
617     }
618   } else {
619     // Skip token for error recovery.
620     getNextToken();
621   }
622 }
623 
624 /// top ::= definition | external | expression | ';'
625 static void MainLoop() {
626   while (true) {
627     fprintf(stderr, "ready> ");
628     switch (CurTok) {
629     case tok_eof:
630       return;
631     case ';': // ignore top-level semicolons.
632       getNextToken();
633       break;
634     case tok_def:
635       HandleDefinition();
636       break;
637     case tok_extern:
638       HandleExtern();
639       break;
640     default:
641       HandleTopLevelExpression();
642       break;
643     }
644   }
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // "Library" functions that can be "extern'd" from user code.
649 //===----------------------------------------------------------------------===//
650 
/// putchard - putchar for doubles: converts X to a char, writes it to stderr,
/// and returns 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
656 
/// printd - printf for doubles: writes X to stderr formatted as "%f\n" and
/// returns 0.
extern "C" double printd(double X) {
  fprintf(stderr, "%f\n", X);
  const double Status = 0.0;
  return Status;
}
662 
663 //===----------------------------------------------------------------------===//
664 // Main driver code.
665 //===----------------------------------------------------------------------===//
666 
667 int main() {
668   InitializeNativeTarget();
669   InitializeNativeTargetAsmPrinter();
670   InitializeNativeTargetAsmParser();
671 
672   // Install standard binary operators.
673   // 1 is lowest precedence.
674   BinopPrecedence['<'] = 10;
675   BinopPrecedence['+'] = 20;
676   BinopPrecedence['-'] = 20;
677   BinopPrecedence['*'] = 40; // highest.
678 
679   // Prime the first token.
680   fprintf(stderr, "ready> ");
681   getNextToken();
682 
683   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
684 
685   InitializeModuleAndPassManager();
686 
687   // Run the main "interpreter loop" now.
688   MainLoop();
689 
690   return 0;
691 }
692