1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 
13 using namespace mlir;
14 using namespace mlir::pdll::ast;
15 
16 /// Copy a string reference into the context with a null terminator.
17 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
18   if (str.empty())
19     return str;
20 
21   char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
22   std::copy(str.begin(), str.end(), data);
23   data[str.size()] = 0;
24   return StringRef(data, str.size());
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // Name
29 //===----------------------------------------------------------------------===//
30 
31 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
32   return *new (ctx.getAllocator().Allocate<Name>())
33       Name(copyStringWithNull(ctx, name), location);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // DeclScope
38 //===----------------------------------------------------------------------===//
39 
40 void DeclScope::add(Decl *decl) {
41   const Name *name = decl->getName();
42   assert(name && "expected a named decl");
43   assert(!decls.count(name->getName()) && "decl with this name already exists");
44   decls.try_emplace(name->getName(), decl);
45 }
46 
47 Decl *DeclScope::lookup(StringRef name) {
48   if (Decl *decl = decls.lookup(name))
49     return decl;
50   return parent ? parent->lookup(name) : nullptr;
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // CompoundStmt
55 //===----------------------------------------------------------------------===//
56 
57 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
58                                    ArrayRef<Stmt *> children) {
59   unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
60   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
61 
62   CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
63   std::uninitialized_copy(children.begin(), children.end(),
64                           stmt->getChildren().begin());
65   return stmt;
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // LetStmt
70 //===----------------------------------------------------------------------===//
71 
72 LetStmt *LetStmt::create(Context &ctx, SMRange loc,
73                          VariableDecl *varDecl) {
74   return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // OpRewriteStmt
79 //===----------------------------------------------------------------------===//
80 
81 //===----------------------------------------------------------------------===//
82 // EraseStmt
83 
84 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
85   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // ReplaceStmt
90 
91 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
92                                  ArrayRef<Expr *> replExprs) {
93   unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
94   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
95 
96   ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
97   std::uninitialized_copy(replExprs.begin(), replExprs.end(),
98                           stmt->getReplExprs().begin());
99   return stmt;
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // RewriteStmt
104 
105 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
106                                  CompoundStmt *rewriteBody) {
107   return new (ctx.getAllocator().Allocate<RewriteStmt>())
108       RewriteStmt(loc, rootOp, rewriteBody);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // ReturnStmt
113 //===----------------------------------------------------------------------===//
114 
115 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
116   return new (ctx.getAllocator().Allocate<ReturnStmt>())
117       ReturnStmt(loc, resultExpr);
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // AttributeExpr
122 //===----------------------------------------------------------------------===//
123 
124 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
125                                      StringRef value) {
126   return new (ctx.getAllocator().Allocate<AttributeExpr>())
127       AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // CallExpr
132 //===----------------------------------------------------------------------===//
133 
134 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
135                            ArrayRef<Expr *> arguments, Type resultType) {
136   unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
137   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
138 
139   CallExpr *expr =
140       new (rawData) CallExpr(loc, resultType, callable, arguments.size());
141   std::uninitialized_copy(arguments.begin(), arguments.end(),
142                           expr->getArguments().begin());
143   return expr;
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // DeclRefExpr
148 //===----------------------------------------------------------------------===//
149 
150 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
151                                  Type type) {
152   return new (ctx.getAllocator().Allocate<DeclRefExpr>())
153       DeclRefExpr(loc, decl, type);
154 }
155 
156 //===----------------------------------------------------------------------===//
157 // MemberAccessExpr
158 //===----------------------------------------------------------------------===//
159 
160 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
161                                            const Expr *parentExpr,
162                                            StringRef memberName, Type type) {
163   return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
164       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
165 }
166 
167 //===----------------------------------------------------------------------===//
168 // OperationExpr
169 //===----------------------------------------------------------------------===//
170 
171 OperationExpr *OperationExpr::create(
172     Context &ctx, SMRange loc, const OpNameDecl *name,
173     ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
174     ArrayRef<NamedAttributeDecl *> attributes) {
175   unsigned allocSize =
176       OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
177           operands.size() + resultTypes.size(), attributes.size());
178   void *rawData =
179       ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
180 
181   Type resultType = OperationType::get(ctx, name->getName());
182   OperationExpr *opExpr = new (rawData)
183       OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
184                     attributes.size(), name->getLoc());
185   std::uninitialized_copy(operands.begin(), operands.end(),
186                           opExpr->getOperands().begin());
187   std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
188                           opExpr->getResultTypes().begin());
189   std::uninitialized_copy(attributes.begin(), attributes.end(),
190                           opExpr->getAttributes().begin());
191   return opExpr;
192 }
193 
194 Optional<StringRef> OperationExpr::getName() const {
195   return getNameDecl()->getName();
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // TupleExpr
200 //===----------------------------------------------------------------------===//
201 
202 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
203                              ArrayRef<Expr *> elements,
204                              ArrayRef<StringRef> names) {
205   unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
206   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
207 
208   auto elementTypes = llvm::map_range(
209       elements, [](const Expr *expr) { return expr->getType(); });
210   TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
211 
212   TupleExpr *expr = new (rawData) TupleExpr(loc, type);
213   std::uninitialized_copy(elements.begin(), elements.end(),
214                           expr->getElements().begin());
215   return expr;
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // TypeExpr
220 //===----------------------------------------------------------------------===//
221 
222 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
223   return new (ctx.getAllocator().Allocate<TypeExpr>())
224       TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // AttrConstraintDecl
229 //===----------------------------------------------------------------------===//
230 
231 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
232                                                Expr *typeExpr) {
233   return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
234       AttrConstraintDecl(loc, typeExpr);
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // OpConstraintDecl
239 //===----------------------------------------------------------------------===//
240 
241 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
242                                            const OpNameDecl *nameDecl) {
243   if (!nameDecl)
244     nameDecl = OpNameDecl::create(ctx, SMRange());
245 
246   return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
247       OpConstraintDecl(loc, nameDecl);
248 }
249 
250 Optional<StringRef> OpConstraintDecl::getName() const {
251   return getNameDecl()->getName();
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // TypeConstraintDecl
256 //===----------------------------------------------------------------------===//
257 
258 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
259                                                SMRange loc) {
260   return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
261       TypeConstraintDecl(loc);
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // TypeRangeConstraintDecl
266 //===----------------------------------------------------------------------===//
267 
268 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
269                                                          SMRange loc) {
270   return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
271       TypeRangeConstraintDecl(loc);
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // ValueConstraintDecl
276 //===----------------------------------------------------------------------===//
277 
278 ValueConstraintDecl *
279 ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
280   return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
281       ValueConstraintDecl(loc, typeExpr);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // ValueRangeConstraintDecl
286 //===----------------------------------------------------------------------===//
287 
288 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
289                                                            SMRange loc,
290                                                            Expr *typeExpr) {
291   return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
292       ValueRangeConstraintDecl(loc, typeExpr);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // UserConstraintDecl
297 //===----------------------------------------------------------------------===//
298 
299 UserConstraintDecl *UserConstraintDecl::createImpl(
300     Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
301     ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
302     const CompoundStmt *body, Type resultType) {
303   unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
304       inputs.size() + results.size());
305   void *rawData =
306       ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
307   if (codeBlock)
308     codeBlock = codeBlock->copy(ctx.getAllocator());
309 
310   UserConstraintDecl *decl = new (rawData) UserConstraintDecl(
311       name, inputs.size(), results.size(), codeBlock, body, resultType);
312   std::uninitialized_copy(inputs.begin(), inputs.end(),
313                           decl->getInputs().begin());
314   std::uninitialized_copy(results.begin(), results.end(),
315                           decl->getResults().begin());
316   return decl;
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // NamedAttributeDecl
321 //===----------------------------------------------------------------------===//
322 
323 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
324                                                Expr *value) {
325   return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
326       NamedAttributeDecl(name, value);
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // OpNameDecl
331 //===----------------------------------------------------------------------===//
332 
333 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
334   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
335 }
336 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
337   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // PatternDecl
342 //===----------------------------------------------------------------------===//
343 
344 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
345                                  const Name *name, Optional<uint16_t> benefit,
346                                  bool hasBoundedRecursion,
347                                  const CompoundStmt *body) {
348   return new (ctx.getAllocator().Allocate<PatternDecl>())
349       PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // UserRewriteDecl
354 //===----------------------------------------------------------------------===//
355 
356 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
357                                              ArrayRef<VariableDecl *> inputs,
358                                              ArrayRef<VariableDecl *> results,
359                                              Optional<StringRef> codeBlock,
360                                              const CompoundStmt *body,
361                                              Type resultType) {
362   unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
363       inputs.size() + results.size());
364   void *rawData =
365       ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
366   if (codeBlock)
367     codeBlock = codeBlock->copy(ctx.getAllocator());
368 
369   UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
370       name, inputs.size(), results.size(), codeBlock, body, resultType);
371   std::uninitialized_copy(inputs.begin(), inputs.end(),
372                           decl->getInputs().begin());
373   std::uninitialized_copy(results.begin(), results.end(),
374                           decl->getResults().begin());
375   return decl;
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // VariableDecl
380 //===----------------------------------------------------------------------===//
381 
382 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
383                                    Expr *initExpr,
384                                    ArrayRef<ConstraintRef> constraints) {
385   unsigned allocSize =
386       VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
387   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
388 
389   VariableDecl *varDecl =
390       new (rawData) VariableDecl(name, type, initExpr, constraints.size());
391   std::uninitialized_copy(constraints.begin(), constraints.end(),
392                           varDecl->getConstraints().begin());
393   return varDecl;
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // Module
398 //===----------------------------------------------------------------------===//
399 
400 Module *Module::create(Context &ctx, SMLoc loc,
401                        ArrayRef<Decl *> children) {
402   unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
403   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
404 
405   Module *module = new (rawData) Module(loc, children.size());
406   std::uninitialized_copy(children.begin(), children.end(),
407                           module->getChildren().begin());
408   return module;
409 }
410