1 //===- AST.h - Node definition for the Toy AST ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AST for the Toy language. It is optimized for
10 // simplicity, not efficiency. The AST forms a tree structure where each node
11 // references its children using std::unique_ptr<>.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef TOY_AST_H
16 #define TOY_AST_H
17 
#include "toy/Lexer.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
25 
26 namespace toy {
27 
28 /// A variable type with either name or shape information.
struct VarType {
  /// Type name, for types identified by name. NOTE(review): presumably left
  /// empty when the type is given only as a shape — confirm against the parser.
  std::string name;
  /// Dimensions, for types described by a shape.
  std::vector<int64_t> shape;
};
33 
34 /// Base class for all expression nodes.
35 class ExprAST {
36 public:
37   enum ExprASTKind {
38     Expr_VarDecl,
39     Expr_Return,
40     Expr_Num,
41     Expr_Literal,
42     Expr_StructLiteral,
43     Expr_Var,
44     Expr_BinOp,
45     Expr_Call,
46     Expr_Print,
47   };
48 
ExprAST(ExprASTKind kind,Location location)49   ExprAST(ExprASTKind kind, Location location)
50       : kind(kind), location(std::move(location)) {}
51   virtual ~ExprAST() = default;
52 
getKind()53   ExprASTKind getKind() const { return kind; }
54 
loc()55   const Location &loc() { return location; }
56 
57 private:
58   const ExprASTKind kind;
59   Location location;
60 };
61 
/// A block-list of expressions: an ordered sequence of owned expression nodes.
/// FunctionAST stores a function's body as one of these.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
64 
65 /// Expression class for numeric literals like "1.0".
66 class NumberExprAST : public ExprAST {
67   double val;
68 
69 public:
NumberExprAST(Location loc,double val)70   NumberExprAST(Location loc, double val)
71       : ExprAST(Expr_Num, std::move(loc)), val(val) {}
72 
getValue()73   double getValue() { return val; }
74 
75   /// LLVM style RTTI
classof(const ExprAST * c)76   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
77 };
78 
79 /// Expression class for a literal value.
80 class LiteralExprAST : public ExprAST {
81   std::vector<std::unique_ptr<ExprAST>> values;
82   std::vector<int64_t> dims;
83 
84 public:
LiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values,std::vector<int64_t> dims)85   LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
86                  std::vector<int64_t> dims)
87       : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)),
88         dims(std::move(dims)) {}
89 
getValues()90   llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
getDims()91   llvm::ArrayRef<int64_t> getDims() { return dims; }
92 
93   /// LLVM style RTTI
classof(const ExprAST * c)94   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
95 };
96 
97 /// Expression class for a literal struct value.
98 class StructLiteralExprAST : public ExprAST {
99   std::vector<std::unique_ptr<ExprAST>> values;
100 
101 public:
StructLiteralExprAST(Location loc,std::vector<std::unique_ptr<ExprAST>> values)102   StructLiteralExprAST(Location loc,
103                        std::vector<std::unique_ptr<ExprAST>> values)
104       : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) {
105   }
106 
getValues()107   llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() { return values; }
108 
109   /// LLVM style RTTI
classof(const ExprAST * c)110   static bool classof(const ExprAST *c) {
111     return c->getKind() == Expr_StructLiteral;
112   }
113 };
114 
115 /// Expression class for referencing a variable, like "a".
116 class VariableExprAST : public ExprAST {
117   std::string name;
118 
119 public:
VariableExprAST(Location loc,llvm::StringRef name)120   VariableExprAST(Location loc, llvm::StringRef name)
121       : ExprAST(Expr_Var, std::move(loc)), name(name) {}
122 
getName()123   llvm::StringRef getName() { return name; }
124 
125   /// LLVM style RTTI
classof(const ExprAST * c)126   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
127 };
128 
129 /// Expression class for defining a variable.
130 class VarDeclExprAST : public ExprAST {
131   std::string name;
132   VarType type;
133   std::unique_ptr<ExprAST> initVal;
134 
135 public:
136   VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
137                  std::unique_ptr<ExprAST> initVal = nullptr)
ExprAST(Expr_VarDecl,std::move (loc))138       : ExprAST(Expr_VarDecl, std::move(loc)), name(name),
139         type(std::move(type)), initVal(std::move(initVal)) {}
140 
getName()141   llvm::StringRef getName() { return name; }
getInitVal()142   ExprAST *getInitVal() { return initVal.get(); }
getType()143   const VarType &getType() { return type; }
144 
145   /// LLVM style RTTI
classof(const ExprAST * c)146   static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
147 };
148 
149 /// Expression class for a return operator.
150 class ReturnExprAST : public ExprAST {
151   llvm::Optional<std::unique_ptr<ExprAST>> expr;
152 
153 public:
ReturnExprAST(Location loc,llvm::Optional<std::unique_ptr<ExprAST>> expr)154   ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
155       : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {}
156 
getExpr()157   llvm::Optional<ExprAST *> getExpr() {
158     if (expr.hasValue())
159       return expr->get();
160     return llvm::None;
161   }
162 
163   /// LLVM style RTTI
classof(const ExprAST * c)164   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
165 };
166 
167 /// Expression class for a binary operator.
168 class BinaryExprAST : public ExprAST {
169   char op;
170   std::unique_ptr<ExprAST> lhs, rhs;
171 
172 public:
getOp()173   char getOp() { return op; }
getLHS()174   ExprAST *getLHS() { return lhs.get(); }
getRHS()175   ExprAST *getRHS() { return rhs.get(); }
176 
BinaryExprAST(Location loc,char op,std::unique_ptr<ExprAST> lhs,std::unique_ptr<ExprAST> rhs)177   BinaryExprAST(Location loc, char op, std::unique_ptr<ExprAST> lhs,
178                 std::unique_ptr<ExprAST> rhs)
179       : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)),
180         rhs(std::move(rhs)) {}
181 
182   /// LLVM style RTTI
classof(const ExprAST * c)183   static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
184 };
185 
186 /// Expression class for function calls.
187 class CallExprAST : public ExprAST {
188   std::string callee;
189   std::vector<std::unique_ptr<ExprAST>> args;
190 
191 public:
CallExprAST(Location loc,const std::string & callee,std::vector<std::unique_ptr<ExprAST>> args)192   CallExprAST(Location loc, const std::string &callee,
193               std::vector<std::unique_ptr<ExprAST>> args)
194       : ExprAST(Expr_Call, std::move(loc)), callee(callee),
195         args(std::move(args)) {}
196 
getCallee()197   llvm::StringRef getCallee() { return callee; }
getArgs()198   llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() { return args; }
199 
200   /// LLVM style RTTI
classof(const ExprAST * c)201   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
202 };
203 
204 /// Expression class for builtin print calls.
205 class PrintExprAST : public ExprAST {
206   std::unique_ptr<ExprAST> arg;
207 
208 public:
PrintExprAST(Location loc,std::unique_ptr<ExprAST> arg)209   PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
210       : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}
211 
getArg()212   ExprAST *getArg() { return arg.get(); }
213 
214   /// LLVM style RTTI
classof(const ExprAST * c)215   static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
216 };
217 
218 /// This class represents the "prototype" for a function, which captures its
219 /// name, and its argument names (thus implicitly the number of arguments the
220 /// function takes).
221 class PrototypeAST {
222   Location location;
223   std::string name;
224   std::vector<std::unique_ptr<VarDeclExprAST>> args;
225 
226 public:
PrototypeAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> args)227   PrototypeAST(Location location, const std::string &name,
228                std::vector<std::unique_ptr<VarDeclExprAST>> args)
229       : location(std::move(location)), name(name), args(std::move(args)) {}
230 
loc()231   const Location &loc() { return location; }
getName()232   llvm::StringRef getName() const { return name; }
getArgs()233   llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getArgs() { return args; }
234 };
235 
236 /// This class represents a top level record in a module.
class RecordAST {
public:
  /// Discriminator for LLVM-style RTTI over top-level records; each subclass
  /// passes its own kind here and checks it in a static classof().
  enum RecordASTKind {
    Record_Function,
    Record_Struct,
  };

  /// `explicit` so a bare RecordASTKind never silently converts into a
  /// RecordAST object; subclasses construct this base directly with their
  /// kind, which is unaffected.
  explicit RecordAST(RecordASTKind kind) : kind(kind) {}
  virtual ~RecordAST() = default;

  /// Returns the kind this record was constructed with.
  RecordASTKind getKind() const { return kind; }

private:
  const RecordASTKind kind;
};
252 
253 /// This class represents a function definition itself.
254 class FunctionAST : public RecordAST {
255   std::unique_ptr<PrototypeAST> proto;
256   std::unique_ptr<ExprASTList> body;
257 
258 public:
FunctionAST(std::unique_ptr<PrototypeAST> proto,std::unique_ptr<ExprASTList> body)259   FunctionAST(std::unique_ptr<PrototypeAST> proto,
260               std::unique_ptr<ExprASTList> body)
261       : RecordAST(Record_Function), proto(std::move(proto)),
262         body(std::move(body)) {}
getProto()263   PrototypeAST *getProto() { return proto.get(); }
getBody()264   ExprASTList *getBody() { return body.get(); }
265 
266   /// LLVM style RTTI
classof(const RecordAST * r)267   static bool classof(const RecordAST *r) {
268     return r->getKind() == Record_Function;
269   }
270 };
271 
272 /// This class represents a struct definition.
273 class StructAST : public RecordAST {
274   Location location;
275   std::string name;
276   std::vector<std::unique_ptr<VarDeclExprAST>> variables;
277 
278 public:
StructAST(Location location,const std::string & name,std::vector<std::unique_ptr<VarDeclExprAST>> variables)279   StructAST(Location location, const std::string &name,
280             std::vector<std::unique_ptr<VarDeclExprAST>> variables)
281       : RecordAST(Record_Struct), location(std::move(location)), name(name),
282         variables(std::move(variables)) {}
283 
loc()284   const Location &loc() { return location; }
getName()285   llvm::StringRef getName() const { return name; }
getVariables()286   llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() {
287     return variables;
288   }
289 
290   /// LLVM style RTTI
classof(const RecordAST * r)291   static bool classof(const RecordAST *r) {
292     return r->getKind() == Record_Struct;
293   }
294 };
295 
/// This class represents a list of top-level records (function and struct
/// definitions) to be processed together.
297 class ModuleAST {
298   std::vector<std::unique_ptr<RecordAST>> records;
299 
300 public:
ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)301   ModuleAST(std::vector<std::unique_ptr<RecordAST>> records)
302       : records(std::move(records)) {}
303 
begin()304   auto begin() { return records.begin(); }
end()305   auto end() { return records.end(); }
306 };
307 
308 void dump(ModuleAST &);
309 
310 } // namespace toy
311 
312 #endif // TOY_AST_H
313