1 #include "llvm/ADT/APFloat.h" 2 #include "llvm/ADT/STLExtras.h" 3 #include "llvm/IR/BasicBlock.h" 4 #include "llvm/IR/Constants.h" 5 #include "llvm/IR/DerivedTypes.h" 6 #include "llvm/IR/Function.h" 7 #include "llvm/IR/IRBuilder.h" 8 #include "llvm/IR/LLVMContext.h" 9 #include "llvm/IR/LegacyPassManager.h" 10 #include "llvm/IR/Module.h" 11 #include "llvm/IR/Type.h" 12 #include "llvm/IR/Verifier.h" 13 #include "llvm/Support/TargetSelect.h" 14 #include "llvm/Target/TargetMachine.h" 15 #include "llvm/Transforms/Scalar.h" 16 #include "llvm/Transforms/Scalar/GVN.h" 17 #include "../include/KaleidoscopeJIT.h" 18 #include <algorithm> 19 #include <cassert> 20 #include <cctype> 21 #include <cstdint> 22 #include <cstdio> 23 #include <cstdlib> 24 #include <map> 25 #include <memory> 26 #include <string> 27 #include <vector> 28 29 using namespace llvm; 30 using namespace llvm::orc; 31 32 //===----------------------------------------------------------------------===// 33 // Lexer 34 //===----------------------------------------------------------------------===// 35 36 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one 37 // of these for known things. 38 enum Token { 39 tok_eof = -1, 40 41 // commands 42 tok_def = -2, 43 tok_extern = -3, 44 45 // primary 46 tok_identifier = -4, 47 tok_number = -5 48 }; 49 50 static std::string IdentifierStr; // Filled in if tok_identifier 51 static double NumVal; // Filled in if tok_number 52 53 /// gettok - Return the next token from standard input. 54 static int gettok() { 55 static int LastChar = ' '; 56 57 // Skip any whitespace. 58 while (isspace(LastChar)) 59 LastChar = getchar(); 60 61 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 62 IdentifierStr = LastChar; 63 while (isalnum((LastChar = getchar()))) 64 IdentifierStr += LastChar; 65 66 if (IdentifierStr == "def") 67 return tok_def; 68 if (IdentifierStr == "extern") 69 return tok_extern; 70 return tok_identifier; 71 } 72 73 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 74 std::string NumStr; 75 do { 76 NumStr += LastChar; 77 LastChar = getchar(); 78 } while (isdigit(LastChar) || LastChar == '.'); 79 80 NumVal = strtod(NumStr.c_str(), nullptr); 81 return tok_number; 82 } 83 84 if (LastChar == '#') { 85 // Comment until end of line. 86 do 87 LastChar = getchar(); 88 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 89 90 if (LastChar != EOF) 91 return gettok(); 92 } 93 94 // Check for end of file. Don't eat the EOF. 95 if (LastChar == EOF) 96 return tok_eof; 97 98 // Otherwise, just return the character as its ascii value. 99 int ThisChar = LastChar; 100 LastChar = getchar(); 101 return ThisChar; 102 } 103 104 //===----------------------------------------------------------------------===// 105 // Abstract Syntax Tree (aka Parse Tree) 106 //===----------------------------------------------------------------------===// 107 108 namespace { 109 110 /// ExprAST - Base class for all expression nodes. 111 class ExprAST { 112 public: 113 virtual ~ExprAST() = default; 114 115 virtual Value *codegen() = 0; 116 }; 117 118 /// NumberExprAST - Expression class for numeric literals like "1.0". 119 class NumberExprAST : public ExprAST { 120 double Val; 121 122 public: 123 NumberExprAST(double Val) : Val(Val) {} 124 125 Value *codegen() override; 126 }; 127 128 /// VariableExprAST - Expression class for referencing a variable, like "a". 129 class VariableExprAST : public ExprAST { 130 std::string Name; 131 132 public: 133 VariableExprAST(const std::string &Name) : Name(Name) {} 134 135 Value *codegen() override; 136 }; 137 138 /// BinaryExprAST - Expression class for a binary operator. 139 class BinaryExprAST : public ExprAST { 140 char Op; 141 std::unique_ptr<ExprAST> LHS, RHS; 142 143 public: 144 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS, 145 std::unique_ptr<ExprAST> RHS) 146 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 147 148 Value *codegen() override; 149 }; 150 151 /// CallExprAST - Expression class for function calls. 152 class CallExprAST : public ExprAST { 153 std::string Callee; 154 std::vector<std::unique_ptr<ExprAST>> Args; 155 156 public: 157 CallExprAST(const std::string &Callee, 158 std::vector<std::unique_ptr<ExprAST>> Args) 159 : Callee(Callee), Args(std::move(Args)) {} 160 161 Value *codegen() override; 162 }; 163 164 /// PrototypeAST - This class represents the "prototype" for a function, 165 /// which captures its name, and its argument names (thus implicitly the number 166 /// of arguments the function takes). 167 class PrototypeAST { 168 std::string Name; 169 std::vector<std::string> Args; 170 171 public: 172 PrototypeAST(const std::string &Name, std::vector<std::string> Args) 173 : Name(Name), Args(std::move(Args)) {} 174 175 Function *codegen(); 176 const std::string &getName() const { return Name; } 177 }; 178 179 /// FunctionAST - This class represents a function definition itself. 180 class FunctionAST { 181 std::unique_ptr<PrototypeAST> Proto; 182 std::unique_ptr<ExprAST> Body; 183 184 public: 185 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 186 std::unique_ptr<ExprAST> Body) 187 : Proto(std::move(Proto)), Body(std::move(Body)) {} 188 189 Function *codegen(); 190 }; 191 192 } // end anonymous namespace 193 194 //===----------------------------------------------------------------------===// 195 // Parser 196 //===----------------------------------------------------------------------===// 197 198 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 199 /// token the parser is looking at. getNextToken reads another token from the 200 /// lexer and updates CurTok with its results. 201 static int CurTok; 202 static int getNextToken() { return CurTok = gettok(); } 203 204 /// BinopPrecedence - This holds the precedence for each binary operator that is 205 /// defined. 206 static std::map<char, int> BinopPrecedence; 207 208 /// GetTokPrecedence - Get the precedence of the pending binary operator token. 209 static int GetTokPrecedence() { 210 if (!isascii(CurTok)) 211 return -1; 212 213 // Make sure it's a declared binop. 214 int TokPrec = BinopPrecedence[CurTok]; 215 if (TokPrec <= 0) 216 return -1; 217 return TokPrec; 218 } 219 220 /// LogError* - These are little helper functions for error handling. 221 std::unique_ptr<ExprAST> LogError(const char *Str) { 222 fprintf(stderr, "Error: %s\n", Str); 223 return nullptr; 224 } 225 226 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) { 227 LogError(Str); 228 return nullptr; 229 } 230 231 static std::unique_ptr<ExprAST> ParseExpression(); 232 233 /// numberexpr ::= number 234 static std::unique_ptr<ExprAST> ParseNumberExpr() { 235 auto Result = llvm::make_unique<NumberExprAST>(NumVal); 236 getNextToken(); // consume the number 237 return std::move(Result); 238 } 239 240 /// parenexpr ::= '(' expression ')' 241 static std::unique_ptr<ExprAST> ParseParenExpr() { 242 getNextToken(); // eat (. 243 auto V = ParseExpression(); 244 if (!V) 245 return nullptr; 246 247 if (CurTok != ')') 248 return LogError("expected ')'"); 249 getNextToken(); // eat ). 250 return V; 251 } 252 253 /// identifierexpr 254 /// ::= identifier 255 /// ::= identifier '(' expression* ')' 256 static std::unique_ptr<ExprAST> ParseIdentifierExpr() { 257 std::string IdName = IdentifierStr; 258 259 getNextToken(); // eat identifier. 260 261 if (CurTok != '(') // Simple variable ref. 262 return llvm::make_unique<VariableExprAST>(IdName); 263 264 // Call. 265 getNextToken(); // eat ( 266 std::vector<std::unique_ptr<ExprAST>> Args; 267 if (CurTok != ')') { 268 while (true) { 269 if (auto Arg = ParseExpression()) 270 Args.push_back(std::move(Arg)); 271 else 272 return nullptr; 273 274 if (CurTok == ')') 275 break; 276 277 if (CurTok != ',') 278 return LogError("Expected ')' or ',' in argument list"); 279 getNextToken(); 280 } 281 } 282 283 // Eat the ')'. 284 getNextToken(); 285 286 return llvm::make_unique<CallExprAST>(IdName, std::move(Args)); 287 } 288 289 /// primary 290 /// ::= identifierexpr 291 /// ::= numberexpr 292 /// ::= parenexpr 293 static std::unique_ptr<ExprAST> ParsePrimary() { 294 switch (CurTok) { 295 default: 296 return LogError("unknown token when expecting an expression"); 297 case tok_identifier: 298 return ParseIdentifierExpr(); 299 case tok_number: 300 return ParseNumberExpr(); 301 case '(': 302 return ParseParenExpr(); 303 } 304 } 305 306 /// binoprhs 307 /// ::= ('+' primary)* 308 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, 309 std::unique_ptr<ExprAST> LHS) { 310 // If this is a binop, find its precedence. 311 while (true) { 312 int TokPrec = GetTokPrecedence(); 313 314 // If this is a binop that binds at least as tightly as the current binop, 315 // consume it, otherwise we are done. 316 if (TokPrec < ExprPrec) 317 return LHS; 318 319 // Okay, we know this is a binop. 320 int BinOp = CurTok; 321 getNextToken(); // eat binop 322 323 // Parse the primary expression after the binary operator. 324 auto RHS = ParsePrimary(); 325 if (!RHS) 326 return nullptr; 327 328 // If BinOp binds less tightly with RHS than the operator after RHS, let 329 // the pending operator take RHS as its LHS. 330 int NextPrec = GetTokPrecedence(); 331 if (TokPrec < NextPrec) { 332 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); 333 if (!RHS) 334 return nullptr; 335 } 336 337 // Merge LHS/RHS. 338 LHS = 339 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS)); 340 } 341 } 342 343 /// expression 344 /// ::= primary binoprhs 345 /// 346 static std::unique_ptr<ExprAST> ParseExpression() { 347 auto LHS = ParsePrimary(); 348 if (!LHS) 349 return nullptr; 350 351 return ParseBinOpRHS(0, std::move(LHS)); 352 } 353 354 /// prototype 355 /// ::= id '(' id* ')' 356 static std::unique_ptr<PrototypeAST> ParsePrototype() { 357 if (CurTok != tok_identifier) 358 return LogErrorP("Expected function name in prototype"); 359 360 std::string FnName = IdentifierStr; 361 getNextToken(); 362 363 if (CurTok != '(') 364 return LogErrorP("Expected '(' in prototype"); 365 366 std::vector<std::string> ArgNames; 367 while (getNextToken() == tok_identifier) 368 ArgNames.push_back(IdentifierStr); 369 if (CurTok != ')') 370 return LogErrorP("Expected ')' in prototype"); 371 372 // success. 373 getNextToken(); // eat ')'. 374 375 return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames)); 376 } 377 378 /// definition ::= 'def' prototype expression 379 static std::unique_ptr<FunctionAST> ParseDefinition() { 380 getNextToken(); // eat def. 381 auto Proto = ParsePrototype(); 382 if (!Proto) 383 return nullptr; 384 385 if (auto E = ParseExpression()) 386 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 387 return nullptr; 388 } 389 390 /// toplevelexpr ::= expression 391 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() { 392 if (auto E = ParseExpression()) { 393 // Make an anonymous proto. 394 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr", 395 std::vector<std::string>()); 396 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 397 } 398 return nullptr; 399 } 400 401 /// external ::= 'extern' prototype 402 static std::unique_ptr<PrototypeAST> ParseExtern() { 403 getNextToken(); // eat extern. 404 return ParsePrototype(); 405 } 406 407 //===----------------------------------------------------------------------===// 408 // Code Generation 409 //===----------------------------------------------------------------------===// 410 411 static LLVMContext TheContext; 412 static IRBuilder<> Builder(TheContext); 413 static std::unique_ptr<Module> TheModule; 414 static std::map<std::string, Value *> NamedValues; 415 static std::unique_ptr<legacy::FunctionPassManager> TheFPM; 416 static std::unique_ptr<KaleidoscopeJIT> TheJIT; 417 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos; 418 419 Value *LogErrorV(const char *Str) { 420 LogError(Str); 421 return nullptr; 422 } 423 424 Function *getFunction(std::string Name) { 425 // First, see if the function has already been added to the current module. 426 if (auto *F = TheModule->getFunction(Name)) 427 return F; 428 429 // If not, check whether we can codegen the declaration from some existing 430 // prototype. 431 auto FI = FunctionProtos.find(Name); 432 if (FI != FunctionProtos.end()) 433 return FI->second->codegen(); 434 435 // If no existing prototype exists, return null. 436 return nullptr; 437 } 438 439 Value *NumberExprAST::codegen() { 440 return ConstantFP::get(TheContext, APFloat(Val)); 441 } 442 443 Value *VariableExprAST::codegen() { 444 // Look this variable up in the function. 445 Value *V = NamedValues[Name]; 446 if (!V) 447 return LogErrorV("Unknown variable name"); 448 return V; 449 } 450 451 Value *BinaryExprAST::codegen() { 452 Value *L = LHS->codegen(); 453 Value *R = RHS->codegen(); 454 if (!L || !R) 455 return nullptr; 456 457 switch (Op) { 458 case '+': 459 return Builder.CreateFAdd(L, R, "addtmp"); 460 case '-': 461 return Builder.CreateFSub(L, R, "subtmp"); 462 case '*': 463 return Builder.CreateFMul(L, R, "multmp"); 464 case '<': 465 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 466 // Convert bool 0/1 to double 0.0 or 1.0 467 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp"); 468 default: 469 return LogErrorV("invalid binary operator"); 470 } 471 } 472 473 Value *CallExprAST::codegen() { 474 // Look up the name in the global module table. 475 Function *CalleeF = getFunction(Callee); 476 if (!CalleeF) 477 return LogErrorV("Unknown function referenced"); 478 479 // If argument mismatch error. 480 if (CalleeF->arg_size() != Args.size()) 481 return LogErrorV("Incorrect # arguments passed"); 482 483 std::vector<Value *> ArgsV; 484 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 485 ArgsV.push_back(Args[i]->codegen()); 486 if (!ArgsV.back()) 487 return nullptr; 488 } 489 490 return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); 491 } 492 493 Function *PrototypeAST::codegen() { 494 // Make the function type: double(double,double) etc. 495 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext)); 496 FunctionType *FT = 497 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false); 498 499 Function *F = 500 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); 501 502 // Set names for all arguments. 503 unsigned Idx = 0; 504 for (auto &Arg : F->args()) 505 Arg.setName(Args[Idx++]); 506 507 return F; 508 } 509 510 Function *FunctionAST::codegen() { 511 // Transfer ownership of the prototype to the FunctionProtos map, but keep a 512 // reference to it for use below. 513 auto &P = *Proto; 514 FunctionProtos[Proto->getName()] = std::move(Proto); 515 Function *TheFunction = getFunction(P.getName()); 516 if (!TheFunction) 517 return nullptr; 518 519 // Create a new basic block to start insertion into. 520 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction); 521 Builder.SetInsertPoint(BB); 522 523 // Record the function arguments in the NamedValues map. 524 NamedValues.clear(); 525 for (auto &Arg : TheFunction->args()) 526 NamedValues[Arg.getName()] = &Arg; 527 528 if (Value *RetVal = Body->codegen()) { 529 // Finish off the function. 530 Builder.CreateRet(RetVal); 531 532 // Validate the generated code, checking for consistency. 533 verifyFunction(*TheFunction); 534 535 // Run the optimizer on the function. 536 TheFPM->run(*TheFunction); 537 538 return TheFunction; 539 } 540 541 // Error reading body, remove function. 542 TheFunction->eraseFromParent(); 543 return nullptr; 544 } 545 546 //===----------------------------------------------------------------------===// 547 // Top-Level parsing and JIT Driver 548 //===----------------------------------------------------------------------===// 549 550 static void InitializeModuleAndPassManager() { 551 // Open a new module. 552 TheModule = llvm::make_unique<Module>("my cool jit", TheContext); 553 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout()); 554 555 // Create a new pass manager attached to it. 556 TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get()); 557 558 // Do simple "peephole" optimizations and bit-twiddling optzns. 559 TheFPM->add(createInstructionCombiningPass()); 560 // Reassociate expressions. 561 TheFPM->add(createReassociatePass()); 562 // Eliminate Common SubExpressions. 563 TheFPM->add(createGVNPass()); 564 // Simplify the control flow graph (deleting unreachable blocks, etc). 565 TheFPM->add(createCFGSimplificationPass()); 566 567 TheFPM->doInitialization(); 568 } 569 570 static void HandleDefinition() { 571 if (auto FnAST = ParseDefinition()) { 572 if (auto *FnIR = FnAST->codegen()) { 573 fprintf(stderr, "Read function definition:"); 574 FnIR->dump(); 575 TheJIT->addModule(std::move(TheModule)); 576 InitializeModuleAndPassManager(); 577 } 578 } else { 579 // Skip token for error recovery. 580 getNextToken(); 581 } 582 } 583 584 static void HandleExtern() { 585 if (auto ProtoAST = ParseExtern()) { 586 if (auto *FnIR = ProtoAST->codegen()) { 587 fprintf(stderr, "Read extern: "); 588 FnIR->dump(); 589 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); 590 } 591 } else { 592 // Skip token for error recovery. 593 getNextToken(); 594 } 595 } 596 597 static void HandleTopLevelExpression() { 598 // Evaluate a top-level expression into an anonymous function. 599 if (auto FnAST = ParseTopLevelExpr()) { 600 if (FnAST->codegen()) { 601 // JIT the module containing the anonymous expression, keeping a handle so 602 // we can free it later. 603 auto H = TheJIT->addModule(std::move(TheModule)); 604 InitializeModuleAndPassManager(); 605 606 // Search the JIT for the __anon_expr symbol. 607 auto ExprSymbol = TheJIT->findSymbol("__anon_expr"); 608 assert(ExprSymbol && "Function not found"); 609 610 // Get the symbol's address and cast it to the right type (takes no 611 // arguments, returns a double) so we can call it as a native function. 612 double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress(); 613 fprintf(stderr, "Evaluated to %f\n", FP()); 614 615 // Delete the anonymous expression module from the JIT. 616 TheJIT->removeModule(H); 617 } 618 } else { 619 // Skip token for error recovery. 620 getNextToken(); 621 } 622 } 623 624 /// top ::= definition | external | expression | ';' 625 static void MainLoop() { 626 while (true) { 627 fprintf(stderr, "ready> "); 628 switch (CurTok) { 629 case tok_eof: 630 return; 631 case ';': // ignore top-level semicolons. 632 getNextToken(); 633 break; 634 case tok_def: 635 HandleDefinition(); 636 break; 637 case tok_extern: 638 HandleExtern(); 639 break; 640 default: 641 HandleTopLevelExpression(); 642 break; 643 } 644 } 645 } 646 647 //===----------------------------------------------------------------------===// 648 // "Library" functions that can be "extern'd" from user code. 649 //===----------------------------------------------------------------------===// 650 651 /// putchard - putchar that takes a double and returns 0. 652 extern "C" double putchard(double X) { 653 fputc((char)X, stderr); 654 return 0; 655 } 656 657 /// printd - printf that takes a double prints it as "%f\n", returning 0. 658 extern "C" double printd(double X) { 659 fprintf(stderr, "%f\n", X); 660 return 0; 661 } 662 663 //===----------------------------------------------------------------------===// 664 // Main driver code. 665 //===----------------------------------------------------------------------===// 666 667 int main() { 668 InitializeNativeTarget(); 669 InitializeNativeTargetAsmPrinter(); 670 InitializeNativeTargetAsmParser(); 671 672 // Install standard binary operators. 673 // 1 is lowest precedence. 674 BinopPrecedence['<'] = 10; 675 BinopPrecedence['+'] = 20; 676 BinopPrecedence['-'] = 20; 677 BinopPrecedence['*'] = 40; // highest. 678 679 // Prime the first token. 680 fprintf(stderr, "ready> "); 681 getNextToken(); 682 683 TheJIT = llvm::make_unique<KaleidoscopeJIT>(); 684 685 InitializeModuleAndPassManager(); 686 687 // Run the main "interpreter loop" now. 688 MainLoop(); 689 690 return 0; 691 } 692