1 #include "llvm/ADT/APFloat.h" 2 #include "llvm/ADT/STLExtras.h" 3 #include "llvm/IR/BasicBlock.h" 4 #include "llvm/IR/Constants.h" 5 #include "llvm/IR/DerivedTypes.h" 6 #include "llvm/IR/Function.h" 7 #include "llvm/IR/IRBuilder.h" 8 #include "llvm/IR/LLVMContext.h" 9 #include "llvm/IR/Module.h" 10 #include "llvm/IR/Type.h" 11 #include "llvm/IR/Verifier.h" 12 #include <algorithm> 13 #include <cctype> 14 #include <cstdio> 15 #include <cstdlib> 16 #include <map> 17 #include <memory> 18 #include <string> 19 #include <vector> 20 21 using namespace llvm; 22 23 //===----------------------------------------------------------------------===// 24 // Lexer 25 //===----------------------------------------------------------------------===// 26 27 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one 28 // of these for known things. 29 enum Token { 30 tok_eof = -1, 31 32 // commands 33 tok_def = -2, 34 tok_extern = -3, 35 36 // primary 37 tok_identifier = -4, 38 tok_number = -5 39 }; 40 41 static std::string IdentifierStr; // Filled in if tok_identifier 42 static double NumVal; // Filled in if tok_number 43 44 /// gettok - Return the next token from standard input. 45 static int gettok() { 46 static int LastChar = ' '; 47 48 // Skip any whitespace. 49 while (isspace(LastChar)) 50 LastChar = getchar(); 51 52 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 53 IdentifierStr = LastChar; 54 while (isalnum((LastChar = getchar()))) 55 IdentifierStr += LastChar; 56 57 if (IdentifierStr == "def") 58 return tok_def; 59 if (IdentifierStr == "extern") 60 return tok_extern; 61 return tok_identifier; 62 } 63 64 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 65 std::string NumStr; 66 do { 67 NumStr += LastChar; 68 LastChar = getchar(); 69 } while (isdigit(LastChar) || LastChar == '.'); 70 71 NumVal = strtod(NumStr.c_str(), nullptr); 72 return tok_number; 73 } 74 75 if (LastChar == '#') { 76 // Comment until end of line. 77 do 78 LastChar = getchar(); 79 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 80 81 if (LastChar != EOF) 82 return gettok(); 83 } 84 85 // Check for end of file. Don't eat the EOF. 86 if (LastChar == EOF) 87 return tok_eof; 88 89 // Otherwise, just return the character as its ascii value. 90 int ThisChar = LastChar; 91 LastChar = getchar(); 92 return ThisChar; 93 } 94 95 //===----------------------------------------------------------------------===// 96 // Abstract Syntax Tree (aka Parse Tree) 97 //===----------------------------------------------------------------------===// 98 99 namespace { 100 101 /// ExprAST - Base class for all expression nodes. 102 class ExprAST { 103 public: 104 virtual ~ExprAST() = default; 105 106 virtual Value *codegen() = 0; 107 }; 108 109 /// NumberExprAST - Expression class for numeric literals like "1.0". 110 class NumberExprAST : public ExprAST { 111 double Val; 112 113 public: 114 NumberExprAST(double Val) : Val(Val) {} 115 116 Value *codegen() override; 117 }; 118 119 /// VariableExprAST - Expression class for referencing a variable, like "a". 120 class VariableExprAST : public ExprAST { 121 std::string Name; 122 123 public: 124 VariableExprAST(const std::string &Name) : Name(Name) {} 125 126 Value *codegen() override; 127 }; 128 129 /// BinaryExprAST - Expression class for a binary operator. 130 class BinaryExprAST : public ExprAST { 131 char Op; 132 std::unique_ptr<ExprAST> LHS, RHS; 133 134 public: 135 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS, 136 std::unique_ptr<ExprAST> RHS) 137 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 138 139 Value *codegen() override; 140 }; 141 142 /// CallExprAST - Expression class for function calls. 143 class CallExprAST : public ExprAST { 144 std::string Callee; 145 std::vector<std::unique_ptr<ExprAST>> Args; 146 147 public: 148 CallExprAST(const std::string &Callee, 149 std::vector<std::unique_ptr<ExprAST>> Args) 150 : Callee(Callee), Args(std::move(Args)) {} 151 152 Value *codegen() override; 153 }; 154 155 /// PrototypeAST - This class represents the "prototype" for a function, 156 /// which captures its name, and its argument names (thus implicitly the number 157 /// of arguments the function takes). 158 class PrototypeAST { 159 std::string Name; 160 std::vector<std::string> Args; 161 162 public: 163 PrototypeAST(const std::string &Name, std::vector<std::string> Args) 164 : Name(Name), Args(std::move(Args)) {} 165 166 Function *codegen(); 167 const std::string &getName() const { return Name; } 168 }; 169 170 /// FunctionAST - This class represents a function definition itself. 171 class FunctionAST { 172 std::unique_ptr<PrototypeAST> Proto; 173 std::unique_ptr<ExprAST> Body; 174 175 public: 176 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 177 std::unique_ptr<ExprAST> Body) 178 : Proto(std::move(Proto)), Body(std::move(Body)) {} 179 180 Function *codegen(); 181 }; 182 183 } // end anonymous namespace 184 185 //===----------------------------------------------------------------------===// 186 // Parser 187 //===----------------------------------------------------------------------===// 188 189 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 190 /// token the parser is looking at. getNextToken reads another token from the 191 /// lexer and updates CurTok with its results. 192 static int CurTok; 193 static int getNextToken() { return CurTok = gettok(); } 194 195 /// BinopPrecedence - This holds the precedence for each binary operator that is 196 /// defined. 197 static std::map<char, int> BinopPrecedence; 198 199 /// GetTokPrecedence - Get the precedence of the pending binary operator token. 200 static int GetTokPrecedence() { 201 if (!isascii(CurTok)) 202 return -1; 203 204 // Make sure it's a declared binop. 205 int TokPrec = BinopPrecedence[CurTok]; 206 if (TokPrec <= 0) 207 return -1; 208 return TokPrec; 209 } 210 211 /// LogError* - These are little helper functions for error handling. 212 std::unique_ptr<ExprAST> LogError(const char *Str) { 213 fprintf(stderr, "Error: %s\n", Str); 214 return nullptr; 215 } 216 217 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) { 218 LogError(Str); 219 return nullptr; 220 } 221 222 static std::unique_ptr<ExprAST> ParseExpression(); 223 224 /// numberexpr ::= number 225 static std::unique_ptr<ExprAST> ParseNumberExpr() { 226 auto Result = llvm::make_unique<NumberExprAST>(NumVal); 227 getNextToken(); // consume the number 228 return std::move(Result); 229 } 230 231 /// parenexpr ::= '(' expression ')' 232 static std::unique_ptr<ExprAST> ParseParenExpr() { 233 getNextToken(); // eat (. 234 auto V = ParseExpression(); 235 if (!V) 236 return nullptr; 237 238 if (CurTok != ')') 239 return LogError("expected ')'"); 240 getNextToken(); // eat ). 241 return V; 242 } 243 244 /// identifierexpr 245 /// ::= identifier 246 /// ::= identifier '(' expression* ')' 247 static std::unique_ptr<ExprAST> ParseIdentifierExpr() { 248 std::string IdName = IdentifierStr; 249 250 getNextToken(); // eat identifier. 251 252 if (CurTok != '(') // Simple variable ref. 253 return llvm::make_unique<VariableExprAST>(IdName); 254 255 // Call. 256 getNextToken(); // eat ( 257 std::vector<std::unique_ptr<ExprAST>> Args; 258 if (CurTok != ')') { 259 while (true) { 260 if (auto Arg = ParseExpression()) 261 Args.push_back(std::move(Arg)); 262 else 263 return nullptr; 264 265 if (CurTok == ')') 266 break; 267 268 if (CurTok != ',') 269 return LogError("Expected ')' or ',' in argument list"); 270 getNextToken(); 271 } 272 } 273 274 // Eat the ')'. 275 getNextToken(); 276 277 return llvm::make_unique<CallExprAST>(IdName, std::move(Args)); 278 } 279 280 /// primary 281 /// ::= identifierexpr 282 /// ::= numberexpr 283 /// ::= parenexpr 284 static std::unique_ptr<ExprAST> ParsePrimary() { 285 switch (CurTok) { 286 default: 287 return LogError("unknown token when expecting an expression"); 288 case tok_identifier: 289 return ParseIdentifierExpr(); 290 case tok_number: 291 return ParseNumberExpr(); 292 case '(': 293 return ParseParenExpr(); 294 } 295 } 296 297 /// binoprhs 298 /// ::= ('+' primary)* 299 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, 300 std::unique_ptr<ExprAST> LHS) { 301 // If this is a binop, find its precedence. 302 while (true) { 303 int TokPrec = GetTokPrecedence(); 304 305 // If this is a binop that binds at least as tightly as the current binop, 306 // consume it, otherwise we are done. 307 if (TokPrec < ExprPrec) 308 return LHS; 309 310 // Okay, we know this is a binop. 311 int BinOp = CurTok; 312 getNextToken(); // eat binop 313 314 // Parse the primary expression after the binary operator. 315 auto RHS = ParsePrimary(); 316 if (!RHS) 317 return nullptr; 318 319 // If BinOp binds less tightly with RHS than the operator after RHS, let 320 // the pending operator take RHS as its LHS. 321 int NextPrec = GetTokPrecedence(); 322 if (TokPrec < NextPrec) { 323 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); 324 if (!RHS) 325 return nullptr; 326 } 327 328 // Merge LHS/RHS. 329 LHS = 330 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS)); 331 } 332 } 333 334 /// expression 335 /// ::= primary binoprhs 336 /// 337 static std::unique_ptr<ExprAST> ParseExpression() { 338 auto LHS = ParsePrimary(); 339 if (!LHS) 340 return nullptr; 341 342 return ParseBinOpRHS(0, std::move(LHS)); 343 } 344 345 /// prototype 346 /// ::= id '(' id* ')' 347 static std::unique_ptr<PrototypeAST> ParsePrototype() { 348 if (CurTok != tok_identifier) 349 return LogErrorP("Expected function name in prototype"); 350 351 std::string FnName = IdentifierStr; 352 getNextToken(); 353 354 if (CurTok != '(') 355 return LogErrorP("Expected '(' in prototype"); 356 357 std::vector<std::string> ArgNames; 358 while (getNextToken() == tok_identifier) 359 ArgNames.push_back(IdentifierStr); 360 if (CurTok != ')') 361 return LogErrorP("Expected ')' in prototype"); 362 363 // success. 364 getNextToken(); // eat ')'. 365 366 return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames)); 367 } 368 369 /// definition ::= 'def' prototype expression 370 static std::unique_ptr<FunctionAST> ParseDefinition() { 371 getNextToken(); // eat def. 372 auto Proto = ParsePrototype(); 373 if (!Proto) 374 return nullptr; 375 376 if (auto E = ParseExpression()) 377 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 378 return nullptr; 379 } 380 381 /// toplevelexpr ::= expression 382 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() { 383 if (auto E = ParseExpression()) { 384 // Make an anonymous proto. 385 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr", 386 std::vector<std::string>()); 387 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 388 } 389 return nullptr; 390 } 391 392 /// external ::= 'extern' prototype 393 static std::unique_ptr<PrototypeAST> ParseExtern() { 394 getNextToken(); // eat extern. 395 return ParsePrototype(); 396 } 397 398 //===----------------------------------------------------------------------===// 399 // Code Generation 400 //===----------------------------------------------------------------------===// 401 402 static LLVMContext TheContext; 403 static IRBuilder<> Builder(TheContext); 404 static std::unique_ptr<Module> TheModule; 405 static std::map<std::string, Value *> NamedValues; 406 407 Value *LogErrorV(const char *Str) { 408 LogError(Str); 409 return nullptr; 410 } 411 412 Value *NumberExprAST::codegen() { 413 return ConstantFP::get(TheContext, APFloat(Val)); 414 } 415 416 Value *VariableExprAST::codegen() { 417 // Look this variable up in the function. 418 Value *V = NamedValues[Name]; 419 if (!V) 420 return LogErrorV("Unknown variable name"); 421 return V; 422 } 423 424 Value *BinaryExprAST::codegen() { 425 Value *L = LHS->codegen(); 426 Value *R = RHS->codegen(); 427 if (!L || !R) 428 return nullptr; 429 430 switch (Op) { 431 case '+': 432 return Builder.CreateFAdd(L, R, "addtmp"); 433 case '-': 434 return Builder.CreateFSub(L, R, "subtmp"); 435 case '*': 436 return Builder.CreateFMul(L, R, "multmp"); 437 case '<': 438 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 439 // Convert bool 0/1 to double 0.0 or 1.0 440 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp"); 441 default: 442 return LogErrorV("invalid binary operator"); 443 } 444 } 445 446 Value *CallExprAST::codegen() { 447 // Look up the name in the global module table. 448 Function *CalleeF = TheModule->getFunction(Callee); 449 if (!CalleeF) 450 return LogErrorV("Unknown function referenced"); 451 452 // If argument mismatch error. 453 if (CalleeF->arg_size() != Args.size()) 454 return LogErrorV("Incorrect # arguments passed"); 455 456 std::vector<Value *> ArgsV; 457 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 458 ArgsV.push_back(Args[i]->codegen()); 459 if (!ArgsV.back()) 460 return nullptr; 461 } 462 463 return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); 464 } 465 466 Function *PrototypeAST::codegen() { 467 // Make the function type: double(double,double) etc. 468 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext)); 469 FunctionType *FT = 470 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false); 471 472 Function *F = 473 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); 474 475 // Set names for all arguments. 476 unsigned Idx = 0; 477 for (auto &Arg : F->args()) 478 Arg.setName(Args[Idx++]); 479 480 return F; 481 } 482 483 Function *FunctionAST::codegen() { 484 // First, check for an existing function from a previous 'extern' declaration. 485 Function *TheFunction = TheModule->getFunction(Proto->getName()); 486 487 if (!TheFunction) 488 TheFunction = Proto->codegen(); 489 490 if (!TheFunction) 491 return nullptr; 492 493 // Create a new basic block to start insertion into. 494 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction); 495 Builder.SetInsertPoint(BB); 496 497 // Record the function arguments in the NamedValues map. 498 NamedValues.clear(); 499 for (auto &Arg : TheFunction->args()) 500 NamedValues[Arg.getName()] = &Arg; 501 502 if (Value *RetVal = Body->codegen()) { 503 // Finish off the function. 504 Builder.CreateRet(RetVal); 505 506 // Validate the generated code, checking for consistency. 507 verifyFunction(*TheFunction); 508 509 return TheFunction; 510 } 511 512 // Error reading body, remove function. 513 TheFunction->eraseFromParent(); 514 return nullptr; 515 } 516 517 //===----------------------------------------------------------------------===// 518 // Top-Level parsing and JIT Driver 519 //===----------------------------------------------------------------------===// 520 521 static void HandleDefinition() { 522 if (auto FnAST = ParseDefinition()) { 523 if (auto *FnIR = FnAST->codegen()) { 524 fprintf(stderr, "Read function definition:"); 525 FnIR->dump(); 526 } 527 } else { 528 // Skip token for error recovery. 529 getNextToken(); 530 } 531 } 532 533 static void HandleExtern() { 534 if (auto ProtoAST = ParseExtern()) { 535 if (auto *FnIR = ProtoAST->codegen()) { 536 fprintf(stderr, "Read extern: "); 537 FnIR->dump(); 538 } 539 } else { 540 // Skip token for error recovery. 541 getNextToken(); 542 } 543 } 544 545 static void HandleTopLevelExpression() { 546 // Evaluate a top-level expression into an anonymous function. 547 if (auto FnAST = ParseTopLevelExpr()) { 548 if (auto *FnIR = FnAST->codegen()) { 549 fprintf(stderr, "Read top-level expression:"); 550 FnIR->dump(); 551 } 552 } else { 553 // Skip token for error recovery. 554 getNextToken(); 555 } 556 } 557 558 /// top ::= definition | external | expression | ';' 559 static void MainLoop() { 560 while (true) { 561 fprintf(stderr, "ready> "); 562 switch (CurTok) { 563 case tok_eof: 564 return; 565 case ';': // ignore top-level semicolons. 566 getNextToken(); 567 break; 568 case tok_def: 569 HandleDefinition(); 570 break; 571 case tok_extern: 572 HandleExtern(); 573 break; 574 default: 575 HandleTopLevelExpression(); 576 break; 577 } 578 } 579 } 580 581 //===----------------------------------------------------------------------===// 582 // Main driver code. 583 //===----------------------------------------------------------------------===// 584 585 int main() { 586 // Install standard binary operators. 587 // 1 is lowest precedence. 588 BinopPrecedence['<'] = 10; 589 BinopPrecedence['+'] = 20; 590 BinopPrecedence['-'] = 20; 591 BinopPrecedence['*'] = 40; // highest. 592 593 // Prime the first token. 594 fprintf(stderr, "ready> "); 595 getNextToken(); 596 597 // Make the module, which holds all the code. 598 TheModule = llvm::make_unique<Module>("my cool jit", TheContext); 599 600 // Run the main "interpreter loop" now. 601 MainLoop(); 602 603 // Print out all of the generated code. 604 TheModule->dump(); 605 606 return 0; 607 } 608