1 //===- AST.h - Node definition for the Toy AST ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the AST for the Toy language. It is optimized for 10 // simplicity, not efficiency. The AST forms a tree structure where each node 11 // references its children using std::unique_ptr<>. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef TOY_AST_H 16 #define TOY_AST_H 17 18 #include "toy/Lexer.h" 19 20 #include "llvm/ADT/ArrayRef.h" 21 #include "llvm/ADT/StringRef.h" 22 #include "llvm/Support/Casting.h" 23 #include <utility> 24 #include <vector> 25 26 namespace toy { 27 28 /// A variable type with either name or shape information. 29 struct VarType { 30 std::string name; 31 std::vector<int64_t> shape; 32 }; 33 34 /// Base class for all expression nodes. 35 class ExprAST { 36 public: 37 enum ExprASTKind { 38 Expr_VarDecl, 39 Expr_Return, 40 Expr_Num, 41 Expr_Literal, 42 Expr_StructLiteral, 43 Expr_Var, 44 Expr_BinOp, 45 Expr_Call, 46 Expr_Print, 47 }; 48 ExprAST(ExprASTKind kind,Location location)49 ExprAST(ExprASTKind kind, Location location) 50 : kind(kind), location(std::move(location)) {} 51 virtual ~ExprAST() = default; 52 getKind()53 ExprASTKind getKind() const { return kind; } 54 loc()55 const Location &loc() { return location; } 56 57 private: 58 const ExprASTKind kind; 59 Location location; 60 }; 61 62 /// A block-list of expressions. 63 using ExprASTList = std::vector<std::unique_ptr<ExprAST>>; 64 65 /// Expression class for numeric literals like "1.0". 66 class NumberExprAST : public ExprAST { 67 double val; 68 69 public: NumberExprAST(Location loc,double val)70 NumberExprAST(Location loc, double val) 71 : ExprAST(Expr_Num, std::move(loc)), val(val) {} 72 getValue()73 double getValue() { return val; } 74 75 /// LLVM style RTTI classof(const ExprAST * c)76 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } 77 }; 78 79 /// Expression class for a literal value. 80 class LiteralExprAST : public ExprAST { 81 std::vector<std::unique_ptr<ExprAST>> values; 82 std::vector<int64_t> dims; 83 84 public: LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)85 LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values, 86 std::vector<int64_t> dims) 87 : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), 88 dims(std::move(dims)) {} 89 getValues()90 llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } getDims()91 llvm::ArrayRef<int64_t> getDims() { return dims; } 92 93 /// LLVM style RTTI classof(const ExprAST * c)94 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } 95 }; 96 97 /// Expression class for a literal struct value. 98 class StructLiteralExprAST : public ExprAST { 99 std::vector<std::unique_ptr<ExprAST>> values; 100 101 public: StructLiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values)102 StructLiteralExprAST(Location loc, 103 std::vector<std::unique_ptr<ExprAST>> values) 104 : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { 105 } 106 getValues()107 llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; } 108 109 /// LLVM style RTTI classof(const ExprAST * c)110 static bool classof(const ExprAST *c) { 111 return c->getKind() == Expr_StructLiteral; 112 } 113 }; 114 115 /// Expression class for referencing a variable, like "a". 116 class VariableExprAST : public ExprAST { 117 std::string name; 118 119 public: VariableExprAST(Location loc,llvm::StringRef name)120 VariableExprAST(Location loc, llvm::StringRef name) 121 : ExprAST(Expr_Var, std::move(loc)), name(name) {} 122 getName()123 llvm::StringRef getName() { return name; } 124 125 /// LLVM style RTTI classof(const ExprAST * c)126 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } 127 }; 128 129 /// Expression class for defining a variable. 130 class VarDeclExprAST : public ExprAST { 131 std::string name; 132 VarType type; 133 std::unique_ptr<ExprAST> initVal; 134 135 public: 136 VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, 137 std::unique_ptr<ExprAST> initVal = nullptr) ExprAST(Expr_VarDecl,std::move (loc))138 : ExprAST(Expr_VarDecl, std::move(loc)), name(name), 139 type(std::move(type)), initVal(std::move(initVal)) {} 140 getName()141 llvm::StringRef getName() { return name; } getInitVal()142 ExprAST *getInitVal() { return initVal.get(); } getType()143 const VarType &getType() { return type; } 144 145 /// LLVM style RTTI classof(const ExprAST * c)146 static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } 147 }; 148 149 /// Expression class for a return operator. 150 class ReturnExprAST : public ExprAST { 151 llvm::Optional<std::unique_ptr<ExprAST>> expr; 152 153 public: ReturnExprAST(Location loc,llvm::Optional<std::unique_ptr<ExprAST>> expr)154 ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr) 155 : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} 156 getExpr()157 llvm::Optional<ExprAST *> getExpr() { 158 if (expr.hasValue()) 159 return expr->get(); 160 return llvm::None; 161 } 162 163 /// LLVM style RTTI classof(const ExprAST * c)164 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } 165 }; 166 167 /// Expression class for a binary operator. 168 class BinaryExprAST : public ExprAST { 169 char op; 170 std::unique_ptr<ExprAST> lhs, rhs; 171 172 public: getOp()173 char getOp() { return op; } getLHS()174 ExprAST *getLHS() { return lhs.get(); } getRHS()175 ExprAST *getRHS() { return rhs.get(); } 176 BinaryExprAST(Location loc,char op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)177 BinaryExprAST(Location loc, char op, std::unique_ptr<ExprAST> lhs, 178 std::unique_ptr<ExprAST> rhs) 179 : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), 180 rhs(std::move(rhs)) {} 181 182 /// LLVM style RTTI classof(const ExprAST * c)183 static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } 184 }; 185 186 /// Expression class for function calls. 187 class CallExprAST : public ExprAST { 188 std::string callee; 189 std::vector<std::unique_ptr<ExprAST>> args; 190 191 public: CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)192 CallExprAST(Location loc, const std::string &callee, 193 std::vector<std::unique_ptr<ExprAST>> args) 194 : ExprAST(Expr_Call, std::move(loc)), callee(callee), 195 args(std::move(args)) {} 196 getCallee()197 llvm::StringRef getCallee() { return callee; } getArgs()198 llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; } 199 200 /// LLVM style RTTI classof(const ExprAST * c)201 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } 202 }; 203 204 /// Expression class for builtin print calls. 205 class PrintExprAST : public ExprAST { 206 std::unique_ptr<ExprAST> arg; 207 208 public: PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)209 PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg) 210 : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} 211 getArg()212 ExprAST *getArg() { return arg.get(); } 213 214 /// LLVM style RTTI classof(const ExprAST * c)215 static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } 216 }; 217 218 /// This class represents the "prototype" for a function, which captures its 219 /// name, and its argument names (thus implicitly the number of arguments the 220 /// function takes). 221 class PrototypeAST { 222 Location location; 223 std::string name; 224 std::vector<std::unique_ptr<VarDeclExprAST>> args; 225 226 public: PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> args)227 PrototypeAST(Location location, const std::string &name, 228 std::vector<std::unique_ptr<VarDeclExprAST>> args) 229 : location(std::move(location)), name(name), args(std::move(args)) {} 230 loc()231 const Location &loc() { return location; } getName()232 llvm::StringRef getName() const { return name; } getArgs()233 llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; } 234 }; 235 236 /// This class represents a top level record in a module. 237 class RecordAST { 238 public: 239 enum RecordASTKind { 240 Record_Function, 241 Record_Struct, 242 }; 243 RecordAST(RecordASTKind kind)244 RecordAST(RecordASTKind kind) : kind(kind) {} 245 virtual ~RecordAST() = default; 246 getKind()247 RecordASTKind getKind() const { return kind; } 248 249 private: 250 const RecordASTKind kind; 251 }; 252 253 /// This class represents a function definition itself. 254 class FunctionAST : public RecordAST { 255 std::unique_ptr<PrototypeAST> proto; 256 std::unique_ptr<ExprASTList> body; 257 258 public: FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)259 FunctionAST(std::unique_ptr<PrototypeAST> proto, 260 std::unique_ptr<ExprASTList> body) 261 : RecordAST(Record_Function), proto(std::move(proto)), 262 body(std::move(body)) {} getProto()263 PrototypeAST *getProto() { return proto.get(); } getBody()264 ExprASTList *getBody() { return body.get(); } 265 266 /// LLVM style RTTI classof(const RecordAST * r)267 static bool classof(const RecordAST *r) { 268 return r->getKind() == Record_Function; 269 } 270 }; 271 272 /// This class represents a struct definition. 273 class StructAST : public RecordAST { 274 Location location; 275 std::string name; 276 std::vector<std::unique_ptr<VarDeclExprAST>> variables; 277 278 public: StructAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> variables)279 StructAST(Location location, const std::string &name, 280 std::vector<std::unique_ptr<VarDeclExprAST>> variables) 281 : RecordAST(Record_Struct), location(std::move(location)), name(name), 282 variables(std::move(variables)) {} 283 loc()284 const Location &loc() { return location; } getName()285 llvm::StringRef getName() const { return name; } getVariables()286 llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() { 287 return variables; 288 } 289 290 /// LLVM style RTTI classof(const RecordAST * r)291 static bool classof(const RecordAST *r) { 292 return r->getKind() == Record_Struct; 293 } 294 }; 295 296 /// This class represents a list of functions to be processed together 297 class ModuleAST { 298 std::vector<std::unique_ptr<RecordAST>> records; 299 300 public: ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)301 ModuleAST(std::vector<std::unique_ptr<RecordAST>> records) 302 : records(std::move(records)) {} 303 begin()304 auto begin() { return records.begin(); } end()305 auto end() { return records.end(); } 306 }; 307 308 void dump(ModuleAST &); 309 310 } // namespace toy 311 312 #endif // TOY_AST_H 313