1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::pdll::ast;
16 
17 /// Copy a string reference into the context with a null terminator.
18 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
19   if (str.empty())
20     return str;
21 
22   char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
23   std::copy(str.begin(), str.end(), data);
24   data[str.size()] = 0;
25   return StringRef(data, str.size());
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Name
30 //===----------------------------------------------------------------------===//
31 
32 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
33   return *new (ctx.getAllocator().Allocate<Name>())
34       Name(copyStringWithNull(ctx, name), location);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Node
39 //===----------------------------------------------------------------------===//
40 
namespace {
/// Walks an AST node and every node reachable from it. `visitFn` is invoked
/// on a node before any of that node's children are visited (pre-order).
///
/// Null children are skipped, and each node is handed to `visitFn` at most
/// once per NodeVisitor instance. The same node can be reachable along
/// several paths: for example, a Decl that is a child of its parent and is
/// also the target of a DeclRefExpr, which this visitor follows into the
/// referenced decl. Seen nodes are therefore tracked in `alreadyVisited`.
///
/// NOTE: `visitFn` is a non-owning function_ref, so the callable must outlive
/// this visitor.
class NodeVisitor {
public:
  explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
      : visitFn(visitFn) {}

  /// Visit `node` if it is non-null and not yet seen: report it to `visitFn`,
  /// then recurse into its children through the matching `visitImpl`.
  void visit(const Node *node) {
    if (!node || !alreadyVisited.insert(node).second)
      return;

    visitFn(node);
    // Dispatch on the concrete node class. Every concrete node kind must be
    // listed here. A kind that is missing falls into Default, which treats it
    // as a programming error.
    TypeSwitch<const Node *>(node)
        .Case<
            // Statements.
            const CompoundStmt, const EraseStmt, const LetStmt,
            const ReplaceStmt, const ReturnStmt, const RewriteStmt,

            // Expressions.
            const AttributeExpr, const CallExpr, const DeclRefExpr,
            const MemberAccessExpr, const OperationExpr, const TupleExpr,
            const TypeExpr,

            // Core Constraint Decls.
            const AttrConstraintDecl, const OpConstraintDecl,
            const TypeConstraintDecl, const TypeRangeConstraintDecl,
            const ValueConstraintDecl, const ValueRangeConstraintDecl,

            // Decls.
            const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
            const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,

            const Module>(
            [&](auto derivedNode) { this->visitImpl(derivedNode); })
        .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
  }

private:
  // Statements. Each overload visits the statement's operands in the order
  // written below. visit() ignores null pointers, so optional operands may be
  // passed through directly.
  void visitImpl(const CompoundStmt *stmt) {
    for (const Node *child : stmt->getChildren())
      visit(child);
  }
  void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
  void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
  void visitImpl(const ReplaceStmt *stmt) {
    visit(stmt->getRootOpExpr());
    for (const Node *child : stmt->getReplExprs())
      visit(child);
  }
  void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
  void visitImpl(const RewriteStmt *stmt) {
    visit(stmt->getRootOpExpr());
    visit(stmt->getRewriteBody());
  }

  // Expressions. AttributeExpr and TypeExpr have nothing further to visit.
  void visitImpl(const AttributeExpr *expr) {}
  void visitImpl(const CallExpr *expr) {
    visit(expr->getCallableExpr());
    for (const Node *child : expr->getArguments())
      visit(child);
  }
  // Follows the reference into the referenced decl (one reason a node can be
  // reached more than once).
  void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
  void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
  void visitImpl(const OperationExpr *expr) {
    visit(expr->getNameDecl());
    for (const Node *child : expr->getOperands())
      visit(child);
    for (const Node *child : expr->getResultTypes())
      visit(child);
    for (const Node *child : expr->getAttributes())
      visit(child);
  }
  void visitImpl(const TupleExpr *expr) {
    for (const Node *child : expr->getElements())
      visit(child);
  }
  void visitImpl(const TypeExpr *expr) {}

  // Core constraint decls. Type-expression operands may be absent, and
  // visit() tolerates null.
  void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
  void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
  void visitImpl(const TypeConstraintDecl *decl) {}
  void visitImpl(const TypeRangeConstraintDecl *decl) {}
  void visitImpl(const ValueConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }
  void visitImpl(const ValueRangeConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }

  // Decls. User constraints and rewrites visit their inputs, then their
  // results, then their body.
  void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
  void visitImpl(const OpNameDecl *decl) {}
  void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
  void visitImpl(const UserConstraintDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  void visitImpl(const UserRewriteDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  // Visits the initializer (null is skipped) and then each constraint's decl.
  void visitImpl(const VariableDecl *decl) {
    visit(decl->getInitExpr());
    for (const ConstraintRef &child : decl->getConstraints())
      visit(child.constraint);
  }

  void visitImpl(const Module *module) {
    for (const Node *child : module->getChildren())
      visit(child);
  }

  /// Callback invoked once per visited node.
  function_ref<void(const Node *)> visitFn;
  /// Nodes already reported to `visitFn`; prevents repeated visits.
  SmallPtrSet<const Node *, 16> alreadyVisited;
};
} // namespace
161 
162 void Node::walk(function_ref<void(const Node *)> walkFn) const {
163   return NodeVisitor(walkFn).visit(this);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // DeclScope
168 //===----------------------------------------------------------------------===//
169 
170 void DeclScope::add(Decl *decl) {
171   const Name *name = decl->getName();
172   assert(name && "expected a named decl");
173   assert(!decls.count(name->getName()) && "decl with this name already exists");
174   decls.try_emplace(name->getName(), decl);
175 }
176 
177 Decl *DeclScope::lookup(StringRef name) {
178   if (Decl *decl = decls.lookup(name))
179     return decl;
180   return parent ? parent->lookup(name) : nullptr;
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // CompoundStmt
185 //===----------------------------------------------------------------------===//
186 
187 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
188                                    ArrayRef<Stmt *> children) {
189   unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
190   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
191 
192   CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
193   std::uninitialized_copy(children.begin(), children.end(),
194                           stmt->getChildren().begin());
195   return stmt;
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // LetStmt
200 //===----------------------------------------------------------------------===//
201 
202 LetStmt *LetStmt::create(Context &ctx, SMRange loc,
203                          VariableDecl *varDecl) {
204   return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // OpRewriteStmt
209 //===----------------------------------------------------------------------===//
210 
211 //===----------------------------------------------------------------------===//
212 // EraseStmt
213 
214 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
215   return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // ReplaceStmt
220 
221 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
222                                  ArrayRef<Expr *> replExprs) {
223   unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
224   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
225 
226   ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
227   std::uninitialized_copy(replExprs.begin(), replExprs.end(),
228                           stmt->getReplExprs().begin());
229   return stmt;
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // RewriteStmt
234 
235 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
236                                  CompoundStmt *rewriteBody) {
237   return new (ctx.getAllocator().Allocate<RewriteStmt>())
238       RewriteStmt(loc, rootOp, rewriteBody);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // ReturnStmt
243 //===----------------------------------------------------------------------===//
244 
245 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
246   return new (ctx.getAllocator().Allocate<ReturnStmt>())
247       ReturnStmt(loc, resultExpr);
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // AttributeExpr
252 //===----------------------------------------------------------------------===//
253 
254 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
255                                      StringRef value) {
256   return new (ctx.getAllocator().Allocate<AttributeExpr>())
257       AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // CallExpr
262 //===----------------------------------------------------------------------===//
263 
264 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
265                            ArrayRef<Expr *> arguments, Type resultType) {
266   unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
267   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
268 
269   CallExpr *expr =
270       new (rawData) CallExpr(loc, resultType, callable, arguments.size());
271   std::uninitialized_copy(arguments.begin(), arguments.end(),
272                           expr->getArguments().begin());
273   return expr;
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // DeclRefExpr
278 //===----------------------------------------------------------------------===//
279 
280 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
281                                  Type type) {
282   return new (ctx.getAllocator().Allocate<DeclRefExpr>())
283       DeclRefExpr(loc, decl, type);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // MemberAccessExpr
288 //===----------------------------------------------------------------------===//
289 
290 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
291                                            const Expr *parentExpr,
292                                            StringRef memberName, Type type) {
293   return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
294       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // OperationExpr
299 //===----------------------------------------------------------------------===//
300 
301 OperationExpr *
302 OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
303                       const OpNameDecl *name, ArrayRef<Expr *> operands,
304                       ArrayRef<Expr *> resultTypes,
305                       ArrayRef<NamedAttributeDecl *> attributes) {
306   unsigned allocSize =
307       OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
308           operands.size() + resultTypes.size(), attributes.size());
309   void *rawData =
310       ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
311 
312   Type resultType = OperationType::get(ctx, name->getName(), odsOp);
313   OperationExpr *opExpr = new (rawData)
314       OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
315                     attributes.size(), name->getLoc());
316   std::uninitialized_copy(operands.begin(), operands.end(),
317                           opExpr->getOperands().begin());
318   std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
319                           opExpr->getResultTypes().begin());
320   std::uninitialized_copy(attributes.begin(), attributes.end(),
321                           opExpr->getAttributes().begin());
322   return opExpr;
323 }
324 
325 Optional<StringRef> OperationExpr::getName() const {
326   return getNameDecl()->getName();
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // TupleExpr
331 //===----------------------------------------------------------------------===//
332 
333 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
334                              ArrayRef<Expr *> elements,
335                              ArrayRef<StringRef> names) {
336   unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
337   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
338 
339   auto elementTypes = llvm::map_range(
340       elements, [](const Expr *expr) { return expr->getType(); });
341   TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
342 
343   TupleExpr *expr = new (rawData) TupleExpr(loc, type);
344   std::uninitialized_copy(elements.begin(), elements.end(),
345                           expr->getElements().begin());
346   return expr;
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // TypeExpr
351 //===----------------------------------------------------------------------===//
352 
353 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
354   return new (ctx.getAllocator().Allocate<TypeExpr>())
355       TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // Decl
360 //===----------------------------------------------------------------------===//
361 
362 void Decl::setDocComment(Context &ctx, StringRef comment) {
363   docComment = comment.copy(ctx.getAllocator());
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // AttrConstraintDecl
368 //===----------------------------------------------------------------------===//
369 
370 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
371                                                Expr *typeExpr) {
372   return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
373       AttrConstraintDecl(loc, typeExpr);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // OpConstraintDecl
378 //===----------------------------------------------------------------------===//
379 
380 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
381                                            const OpNameDecl *nameDecl) {
382   if (!nameDecl)
383     nameDecl = OpNameDecl::create(ctx, SMRange());
384 
385   return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
386       OpConstraintDecl(loc, nameDecl);
387 }
388 
389 Optional<StringRef> OpConstraintDecl::getName() const {
390   return getNameDecl()->getName();
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // TypeConstraintDecl
395 //===----------------------------------------------------------------------===//
396 
397 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
398                                                SMRange loc) {
399   return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
400       TypeConstraintDecl(loc);
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // TypeRangeConstraintDecl
405 //===----------------------------------------------------------------------===//
406 
407 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
408                                                          SMRange loc) {
409   return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
410       TypeRangeConstraintDecl(loc);
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // ValueConstraintDecl
415 //===----------------------------------------------------------------------===//
416 
417 ValueConstraintDecl *
418 ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
419   return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
420       ValueConstraintDecl(loc, typeExpr);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // ValueRangeConstraintDecl
425 //===----------------------------------------------------------------------===//
426 
427 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
428                                                            SMRange loc,
429                                                            Expr *typeExpr) {
430   return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
431       ValueRangeConstraintDecl(loc, typeExpr);
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // UserConstraintDecl
436 //===----------------------------------------------------------------------===//
437 
438 Optional<StringRef>
439 UserConstraintDecl::getNativeInputType(unsigned index) const {
440   return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index]
441                              : Optional<StringRef>();
442 }
443 
444 UserConstraintDecl *UserConstraintDecl::createImpl(
445     Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
446     ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results,
447     Optional<StringRef> codeBlock, const CompoundStmt *body, Type resultType) {
448   bool hasNativeInputTypes = !nativeInputTypes.empty();
449   assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size());
450 
451   unsigned allocSize =
452       UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>(
453           inputs.size() + results.size(),
454           hasNativeInputTypes ? inputs.size() : 0);
455   void *rawData =
456       ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
457   if (codeBlock)
458     codeBlock = codeBlock->copy(ctx.getAllocator());
459 
460   UserConstraintDecl *decl = new (rawData)
461       UserConstraintDecl(name, inputs.size(), hasNativeInputTypes,
462                          results.size(), codeBlock, body, resultType);
463   std::uninitialized_copy(inputs.begin(), inputs.end(),
464                           decl->getInputs().begin());
465   std::uninitialized_copy(results.begin(), results.end(),
466                           decl->getResults().begin());
467   if (hasNativeInputTypes) {
468     StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>();
469     for (unsigned i = 0, e = inputs.size(); i < e; ++i)
470       nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator());
471   }
472 
473   return decl;
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // NamedAttributeDecl
478 //===----------------------------------------------------------------------===//
479 
480 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
481                                                Expr *value) {
482   return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
483       NamedAttributeDecl(name, value);
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // OpNameDecl
488 //===----------------------------------------------------------------------===//
489 
490 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
491   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
492 }
493 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
494   return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // PatternDecl
499 //===----------------------------------------------------------------------===//
500 
501 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc,
502                                  const Name *name, Optional<uint16_t> benefit,
503                                  bool hasBoundedRecursion,
504                                  const CompoundStmt *body) {
505   return new (ctx.getAllocator().Allocate<PatternDecl>())
506       PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // UserRewriteDecl
511 //===----------------------------------------------------------------------===//
512 
513 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
514                                              ArrayRef<VariableDecl *> inputs,
515                                              ArrayRef<VariableDecl *> results,
516                                              Optional<StringRef> codeBlock,
517                                              const CompoundStmt *body,
518                                              Type resultType) {
519   unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
520       inputs.size() + results.size());
521   void *rawData =
522       ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
523   if (codeBlock)
524     codeBlock = codeBlock->copy(ctx.getAllocator());
525 
526   UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
527       name, inputs.size(), results.size(), codeBlock, body, resultType);
528   std::uninitialized_copy(inputs.begin(), inputs.end(),
529                           decl->getInputs().begin());
530   std::uninitialized_copy(results.begin(), results.end(),
531                           decl->getResults().begin());
532   return decl;
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // VariableDecl
537 //===----------------------------------------------------------------------===//
538 
539 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
540                                    Expr *initExpr,
541                                    ArrayRef<ConstraintRef> constraints) {
542   unsigned allocSize =
543       VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
544   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
545 
546   VariableDecl *varDecl =
547       new (rawData) VariableDecl(name, type, initExpr, constraints.size());
548   std::uninitialized_copy(constraints.begin(), constraints.end(),
549                           varDecl->getConstraints().begin());
550   return varDecl;
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Module
555 //===----------------------------------------------------------------------===//
556 
557 Module *Module::create(Context &ctx, SMLoc loc,
558                        ArrayRef<Decl *> children) {
559   unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
560   void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
561 
562   Module *module = new (rawData) Module(loc, children.size());
563   std::uninitialized_copy(children.begin(), children.end(),
564                           module->getChildren().begin());
565   return module;
566 }
567