1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 using namespace mlir::pdll::ast;
15 
16 /// Copy a string reference into the context with a null terminator.
17 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
18   if (str.empty())
19     return str;
20 
21   char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
22   std::copy(str.begin(), str.end(), data);
23   data[str.size()] = 0;
24   return StringRef(data, str.size());
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // Name
29 //===----------------------------------------------------------------------===//
30 
31 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
32   return *new (ctx.getAllocator().Allocate<Name>())
33       Name(copyStringWithNull(ctx, name), location);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // DeclScope
38 //===----------------------------------------------------------------------===//
39 
40 void DeclScope::add(Decl *decl) {
41   const Name *name = decl->getName();
42   assert(name && "expected a named decl");
43   assert(!decls.count(name->getName()) && "decl with this name already exists");
44   decls.try_emplace(name->getName(), decl);
45 }
46 
47 Decl *DeclScope::lookup(StringRef name) {
48   if (Decl *decl = decls.lookup(name))
49     return decl;
50   return parent ? parent->lookup(name) : nullptr;
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // CompoundStmt
55 //===----------------------------------------------------------------------===//
56 
57 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
58                                    ArrayRef<Stmt *> children) {
59   unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
60   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
61 
62   CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
63   std::uninitialized_copy(children.begin(), children.end(),
64                           stmt->getChildren().begin());
65   return stmt;
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // LetStmt
70 //===----------------------------------------------------------------------===//
71 
72 LetStmt *LetStmt::create(Context &ctx, SMRange loc,
73                          VariableDecl *varDecl) {
74   return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // OpRewriteStmt
79 //===----------------------------------------------------------------------===//
80 
81 //===----------------------------------------------------------------------===//
82 // EraseStmt
83 
84 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
85   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // ReplaceStmt
90 
91 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
92                                  ArrayRef<Expr *> replExprs) {
93   unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
94   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
95 
96   ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
97   std::uninitialized_copy(replExprs.begin(), replExprs.end(),
98                           stmt->getReplExprs().begin());
99   return stmt;
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // RewriteStmt
104 
105 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
106                                  CompoundStmt *rewriteBody) {
107   return new (ctx.getAllocator().Allocate<RewriteStmt>())
108       RewriteStmt(loc, rootOp, rewriteBody);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // AttributeExpr
113 //===----------------------------------------------------------------------===//
114 
115 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
116                                      StringRef value) {
117   return new (ctx.getAllocator().Allocate<AttributeExpr>())
118       AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // DeclRefExpr
123 //===----------------------------------------------------------------------===//
124 
125 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
126                                  Type type) {
127   return new (ctx.getAllocator().Allocate<DeclRefExpr>())
128       DeclRefExpr(loc, decl, type);
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // MemberAccessExpr
133 //===----------------------------------------------------------------------===//
134 
135 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
136                                            const Expr *parentExpr,
137                                            StringRef memberName, Type type) {
138   return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
139       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // OperationExpr
144 //===----------------------------------------------------------------------===//
145 
146 OperationExpr *OperationExpr::create(
147     Context &ctx, SMRange loc, const OpNameDecl *name,
148     ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
149     ArrayRef<NamedAttributeDecl *> attributes) {
150   unsigned allocSize =
151       OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
152           operands.size() + resultTypes.size(), attributes.size());
153   void *rawData =
154       ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
155 
156   Type resultType = OperationType::get(ctx, name->getName());
157   OperationExpr *opExpr = new (rawData)
158       OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
159                     attributes.size(), name->getLoc());
160   std::uninitialized_copy(operands.begin(), operands.end(),
161                           opExpr->getOperands().begin());
162   std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
163                           opExpr->getResultTypes().begin());
164   std::uninitialized_copy(attributes.begin(), attributes.end(),
165                           opExpr->getAttributes().begin());
166   return opExpr;
167 }
168 
169 Optional<StringRef> OperationExpr::getName() const {
170   return getNameDecl()->getName();
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // TupleExpr
175 //===----------------------------------------------------------------------===//
176 
177 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
178                              ArrayRef<Expr *> elements,
179                              ArrayRef<StringRef> names) {
180   unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
181   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
182 
183   auto elementTypes = llvm::map_range(
184       elements, [](const Expr *expr) { return expr->getType(); });
185   TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
186 
187   TupleExpr *expr = new (rawData) TupleExpr(loc, type);
188   std::uninitialized_copy(elements.begin(), elements.end(),
189                           expr->getElements().begin());
190   return expr;
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // TypeExpr
195 //===----------------------------------------------------------------------===//
196 
197 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
198   return new (ctx.getAllocator().Allocate<TypeExpr>())
199       TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // AttrConstraintDecl
204 //===----------------------------------------------------------------------===//
205 
206 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
207                                                Expr *typeExpr) {
208   return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
209       AttrConstraintDecl(loc, typeExpr);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // OpConstraintDecl
214 //===----------------------------------------------------------------------===//
215 
216 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
217                                            const OpNameDecl *nameDecl) {
218   if (!nameDecl)
219     nameDecl = OpNameDecl::create(ctx, SMRange());
220 
221   return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
222       OpConstraintDecl(loc, nameDecl);
223 }
224 
225 Optional<StringRef> OpConstraintDecl::getName() const {
226   return getNameDecl()->getName();
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // TypeConstraintDecl
231 //===----------------------------------------------------------------------===//
232 
233 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
234                                                SMRange loc) {
235   return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
236       TypeConstraintDecl(loc);
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // TypeRangeConstraintDecl
241 //===----------------------------------------------------------------------===//
242 
243 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
244                                                          SMRange loc) {
245   return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
246       TypeRangeConstraintDecl(loc);
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // ValueConstraintDecl
251 //===----------------------------------------------------------------------===//
252 
253 ValueConstraintDecl *
254 ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
255   return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
256       ValueConstraintDecl(loc, typeExpr);
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ValueRangeConstraintDecl
261 //===----------------------------------------------------------------------===//
262 
263 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
264                                                            SMRange loc,
265                                                            Expr *typeExpr) {
266   return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
267       ValueRangeConstraintDecl(loc, typeExpr);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // NamedAttributeDecl
272 //===----------------------------------------------------------------------===//
273 
274 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
275                                                Expr *value) {
276   return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
277       NamedAttributeDecl(name, value);
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // OpNameDecl
282 //===----------------------------------------------------------------------===//
283 
284 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
285   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
286 }
287 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
288   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
289 }
290 
291 //===----------------------------------------------------------------------===//
292 // PatternDecl
293 //===----------------------------------------------------------------------===//
294 
295 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
296                                  const Name *name, Optional<uint16_t> benefit,
297                                  bool hasBoundedRecursion,
298                                  const CompoundStmt *body) {
299   return new (ctx.getAllocator().Allocate<PatternDecl>())
300       PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // VariableDecl
305 //===----------------------------------------------------------------------===//
306 
307 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
308                                    Expr *initExpr,
309                                    ArrayRef<ConstraintRef> constraints) {
310   unsigned allocSize =
311       VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
312   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
313 
314   VariableDecl *varDecl =
315       new (rawData) VariableDecl(name, type, initExpr, constraints.size());
316   std::uninitialized_copy(constraints.begin(), constraints.end(),
317                           varDecl->getConstraints().begin());
318   return varDecl;
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // Module
323 //===----------------------------------------------------------------------===//
324 
325 Module *Module::create(Context &ctx, SMLoc loc,
326                        ArrayRef<Decl *> children) {
327   unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
328   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
329 
330   Module *module = new (rawData) Module(loc, children.size());
331   std::uninitialized_copy(children.begin(), children.end(),
332                           module->getChildren().begin());
333   return module;
334 }
335