1 //===- Nodes.cpp ----------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/AST/Nodes.h"
10 #include "mlir/Tools/PDLL/AST/Context.h"
11 #include "llvm/ADT/SmallPtrSet.h"
12 #include "llvm/ADT/TypeSwitch.h"
13
14 using namespace mlir;
15 using namespace mlir::pdll::ast;
16
17 /// Copy a string reference into the context with a null terminator.
copyStringWithNull(Context & ctx,StringRef str)18 static StringRef copyStringWithNull(Context &ctx, StringRef str) {
19 if (str.empty())
20 return str;
21
22 char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
23 std::copy(str.begin(), str.end(), data);
24 data[str.size()] = 0;
25 return StringRef(data, str.size());
26 }
27
28 //===----------------------------------------------------------------------===//
29 // Name
30 //===----------------------------------------------------------------------===//
31
create(Context & ctx,StringRef name,SMRange location)32 const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
33 return *new (ctx.getAllocator().Allocate<Name>())
34 Name(copyStringWithNull(ctx, name), location);
35 }
36
37 //===----------------------------------------------------------------------===//
38 // Node
39 //===----------------------------------------------------------------------===//
40
41 namespace {
42 class NodeVisitor {
43 public:
NodeVisitor(function_ref<void (const Node *)> visitFn)44 explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
45 : visitFn(visitFn) {}
46
visit(const Node * node)47 void visit(const Node *node) {
48 if (!node || !alreadyVisited.insert(node).second)
49 return;
50
51 visitFn(node);
52 TypeSwitch<const Node *>(node)
53 .Case<
54 // Statements.
55 const CompoundStmt, const EraseStmt, const LetStmt,
56 const ReplaceStmt, const ReturnStmt, const RewriteStmt,
57
58 // Expressions.
59 const AttributeExpr, const CallExpr, const DeclRefExpr,
60 const MemberAccessExpr, const OperationExpr, const TupleExpr,
61 const TypeExpr,
62
63 // Core Constraint Decls.
64 const AttrConstraintDecl, const OpConstraintDecl,
65 const TypeConstraintDecl, const TypeRangeConstraintDecl,
66 const ValueConstraintDecl, const ValueRangeConstraintDecl,
67
68 // Decls.
69 const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
70 const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,
71
72 const Module>(
73 [&](auto derivedNode) { this->visitImpl(derivedNode); })
74 .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
75 }
76
77 private:
visitImpl(const CompoundStmt * stmt)78 void visitImpl(const CompoundStmt *stmt) {
79 for (const Node *child : stmt->getChildren())
80 visit(child);
81 }
visitImpl(const EraseStmt * stmt)82 void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
visitImpl(const LetStmt * stmt)83 void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
visitImpl(const ReplaceStmt * stmt)84 void visitImpl(const ReplaceStmt *stmt) {
85 visit(stmt->getRootOpExpr());
86 for (const Node *child : stmt->getReplExprs())
87 visit(child);
88 }
visitImpl(const ReturnStmt * stmt)89 void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
visitImpl(const RewriteStmt * stmt)90 void visitImpl(const RewriteStmt *stmt) {
91 visit(stmt->getRootOpExpr());
92 visit(stmt->getRewriteBody());
93 }
94
visitImpl(const AttributeExpr * expr)95 void visitImpl(const AttributeExpr *expr) {}
visitImpl(const CallExpr * expr)96 void visitImpl(const CallExpr *expr) {
97 visit(expr->getCallableExpr());
98 for (const Node *child : expr->getArguments())
99 visit(child);
100 }
visitImpl(const DeclRefExpr * expr)101 void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
visitImpl(const MemberAccessExpr * expr)102 void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
visitImpl(const OperationExpr * expr)103 void visitImpl(const OperationExpr *expr) {
104 visit(expr->getNameDecl());
105 for (const Node *child : expr->getOperands())
106 visit(child);
107 for (const Node *child : expr->getResultTypes())
108 visit(child);
109 for (const Node *child : expr->getAttributes())
110 visit(child);
111 }
visitImpl(const TupleExpr * expr)112 void visitImpl(const TupleExpr *expr) {
113 for (const Node *child : expr->getElements())
114 visit(child);
115 }
visitImpl(const TypeExpr * expr)116 void visitImpl(const TypeExpr *expr) {}
117
visitImpl(const AttrConstraintDecl * decl)118 void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
visitImpl(const OpConstraintDecl * decl)119 void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
visitImpl(const TypeConstraintDecl * decl)120 void visitImpl(const TypeConstraintDecl *decl) {}
visitImpl(const TypeRangeConstraintDecl * decl)121 void visitImpl(const TypeRangeConstraintDecl *decl) {}
visitImpl(const ValueConstraintDecl * decl)122 void visitImpl(const ValueConstraintDecl *decl) {
123 visit(decl->getTypeExpr());
124 }
visitImpl(const ValueRangeConstraintDecl * decl)125 void visitImpl(const ValueRangeConstraintDecl *decl) {
126 visit(decl->getTypeExpr());
127 }
128
visitImpl(const NamedAttributeDecl * decl)129 void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
visitImpl(const OpNameDecl * decl)130 void visitImpl(const OpNameDecl *decl) {}
visitImpl(const PatternDecl * decl)131 void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
visitImpl(const UserConstraintDecl * decl)132 void visitImpl(const UserConstraintDecl *decl) {
133 for (const Node *child : decl->getInputs())
134 visit(child);
135 for (const Node *child : decl->getResults())
136 visit(child);
137 visit(decl->getBody());
138 }
visitImpl(const UserRewriteDecl * decl)139 void visitImpl(const UserRewriteDecl *decl) {
140 for (const Node *child : decl->getInputs())
141 visit(child);
142 for (const Node *child : decl->getResults())
143 visit(child);
144 visit(decl->getBody());
145 }
visitImpl(const VariableDecl * decl)146 void visitImpl(const VariableDecl *decl) {
147 visit(decl->getInitExpr());
148 for (const ConstraintRef &child : decl->getConstraints())
149 visit(child.constraint);
150 }
151
visitImpl(const Module * module)152 void visitImpl(const Module *module) {
153 for (const Node *child : module->getChildren())
154 visit(child);
155 }
156
157 function_ref<void(const Node *)> visitFn;
158 SmallPtrSet<const Node *, 16> alreadyVisited;
159 };
160 } // namespace
161
walk(function_ref<void (const Node *)> walkFn) const162 void Node::walk(function_ref<void(const Node *)> walkFn) const {
163 return NodeVisitor(walkFn).visit(this);
164 }
165
166 //===----------------------------------------------------------------------===//
167 // DeclScope
168 //===----------------------------------------------------------------------===//
169
add(Decl * decl)170 void DeclScope::add(Decl *decl) {
171 const Name *name = decl->getName();
172 assert(name && "expected a named decl");
173 assert(!decls.count(name->getName()) && "decl with this name already exists");
174 decls.try_emplace(name->getName(), decl);
175 }
176
lookup(StringRef name)177 Decl *DeclScope::lookup(StringRef name) {
178 if (Decl *decl = decls.lookup(name))
179 return decl;
180 return parent ? parent->lookup(name) : nullptr;
181 }
182
183 //===----------------------------------------------------------------------===//
184 // CompoundStmt
185 //===----------------------------------------------------------------------===//
186
create(Context & ctx,SMRange loc,ArrayRef<Stmt * > children)187 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
188 ArrayRef<Stmt *> children) {
189 unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
190 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));
191
192 CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
193 std::uninitialized_copy(children.begin(), children.end(),
194 stmt->getChildren().begin());
195 return stmt;
196 }
197
198 //===----------------------------------------------------------------------===//
199 // LetStmt
200 //===----------------------------------------------------------------------===//
201
create(Context & ctx,SMRange loc,VariableDecl * varDecl)202 LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) {
203 return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
204 }
205
206 //===----------------------------------------------------------------------===//
207 // OpRewriteStmt
208 //===----------------------------------------------------------------------===//
209
210 //===----------------------------------------------------------------------===//
211 // EraseStmt
212
create(Context & ctx,SMRange loc,Expr * rootOp)213 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
214 return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
215 }
216
217 //===----------------------------------------------------------------------===//
218 // ReplaceStmt
219
create(Context & ctx,SMRange loc,Expr * rootOp,ArrayRef<Expr * > replExprs)220 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
221 ArrayRef<Expr *> replExprs) {
222 unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
223 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));
224
225 ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
226 std::uninitialized_copy(replExprs.begin(), replExprs.end(),
227 stmt->getReplExprs().begin());
228 return stmt;
229 }
230
231 //===----------------------------------------------------------------------===//
232 // RewriteStmt
233
create(Context & ctx,SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)234 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
235 CompoundStmt *rewriteBody) {
236 return new (ctx.getAllocator().Allocate<RewriteStmt>())
237 RewriteStmt(loc, rootOp, rewriteBody);
238 }
239
240 //===----------------------------------------------------------------------===//
241 // ReturnStmt
242 //===----------------------------------------------------------------------===//
243
create(Context & ctx,SMRange loc,Expr * resultExpr)244 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
245 return new (ctx.getAllocator().Allocate<ReturnStmt>())
246 ReturnStmt(loc, resultExpr);
247 }
248
249 //===----------------------------------------------------------------------===//
250 // AttributeExpr
251 //===----------------------------------------------------------------------===//
252
create(Context & ctx,SMRange loc,StringRef value)253 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
254 StringRef value) {
255 return new (ctx.getAllocator().Allocate<AttributeExpr>())
256 AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
257 }
258
259 //===----------------------------------------------------------------------===//
260 // CallExpr
261 //===----------------------------------------------------------------------===//
262
create(Context & ctx,SMRange loc,Expr * callable,ArrayRef<Expr * > arguments,Type resultType)263 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
264 ArrayRef<Expr *> arguments, Type resultType) {
265 unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
266 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));
267
268 CallExpr *expr =
269 new (rawData) CallExpr(loc, resultType, callable, arguments.size());
270 std::uninitialized_copy(arguments.begin(), arguments.end(),
271 expr->getArguments().begin());
272 return expr;
273 }
274
275 //===----------------------------------------------------------------------===//
276 // DeclRefExpr
277 //===----------------------------------------------------------------------===//
278
create(Context & ctx,SMRange loc,Decl * decl,Type type)279 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
280 Type type) {
281 return new (ctx.getAllocator().Allocate<DeclRefExpr>())
282 DeclRefExpr(loc, decl, type);
283 }
284
285 //===----------------------------------------------------------------------===//
286 // MemberAccessExpr
287 //===----------------------------------------------------------------------===//
288
create(Context & ctx,SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)289 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
290 const Expr *parentExpr,
291 StringRef memberName, Type type) {
292 return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
293 loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
294 }
295
296 //===----------------------------------------------------------------------===//
297 // OperationExpr
298 //===----------------------------------------------------------------------===//
299
300 OperationExpr *
create(Context & ctx,SMRange loc,const ods::Operation * odsOp,const OpNameDecl * name,ArrayRef<Expr * > operands,ArrayRef<Expr * > resultTypes,ArrayRef<NamedAttributeDecl * > attributes)301 OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
302 const OpNameDecl *name, ArrayRef<Expr *> operands,
303 ArrayRef<Expr *> resultTypes,
304 ArrayRef<NamedAttributeDecl *> attributes) {
305 unsigned allocSize =
306 OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
307 operands.size() + resultTypes.size(), attributes.size());
308 void *rawData =
309 ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
310
311 Type resultType = OperationType::get(ctx, name->getName(), odsOp);
312 OperationExpr *opExpr = new (rawData)
313 OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
314 attributes.size(), name->getLoc());
315 std::uninitialized_copy(operands.begin(), operands.end(),
316 opExpr->getOperands().begin());
317 std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
318 opExpr->getResultTypes().begin());
319 std::uninitialized_copy(attributes.begin(), attributes.end(),
320 opExpr->getAttributes().begin());
321 return opExpr;
322 }
323
getName() const324 Optional<StringRef> OperationExpr::getName() const {
325 return getNameDecl()->getName();
326 }
327
328 //===----------------------------------------------------------------------===//
329 // TupleExpr
330 //===----------------------------------------------------------------------===//
331
create(Context & ctx,SMRange loc,ArrayRef<Expr * > elements,ArrayRef<StringRef> names)332 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
333 ArrayRef<Expr *> elements,
334 ArrayRef<StringRef> names) {
335 unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
336 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
337
338 auto elementTypes = llvm::map_range(
339 elements, [](const Expr *expr) { return expr->getType(); });
340 TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
341
342 TupleExpr *expr = new (rawData) TupleExpr(loc, type);
343 std::uninitialized_copy(elements.begin(), elements.end(),
344 expr->getElements().begin());
345 return expr;
346 }
347
348 //===----------------------------------------------------------------------===//
349 // TypeExpr
350 //===----------------------------------------------------------------------===//
351
create(Context & ctx,SMRange loc,StringRef value)352 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
353 return new (ctx.getAllocator().Allocate<TypeExpr>())
354 TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
355 }
356
357 //===----------------------------------------------------------------------===//
358 // Decl
359 //===----------------------------------------------------------------------===//
360
setDocComment(Context & ctx,StringRef comment)361 void Decl::setDocComment(Context &ctx, StringRef comment) {
362 docComment = comment.copy(ctx.getAllocator());
363 }
364
365 //===----------------------------------------------------------------------===//
366 // AttrConstraintDecl
367 //===----------------------------------------------------------------------===//
368
create(Context & ctx,SMRange loc,Expr * typeExpr)369 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
370 Expr *typeExpr) {
371 return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
372 AttrConstraintDecl(loc, typeExpr);
373 }
374
375 //===----------------------------------------------------------------------===//
376 // OpConstraintDecl
377 //===----------------------------------------------------------------------===//
378
create(Context & ctx,SMRange loc,const OpNameDecl * nameDecl)379 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
380 const OpNameDecl *nameDecl) {
381 if (!nameDecl)
382 nameDecl = OpNameDecl::create(ctx, SMRange());
383
384 return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
385 OpConstraintDecl(loc, nameDecl);
386 }
387
getName() const388 Optional<StringRef> OpConstraintDecl::getName() const {
389 return getNameDecl()->getName();
390 }
391
392 //===----------------------------------------------------------------------===//
393 // TypeConstraintDecl
394 //===----------------------------------------------------------------------===//
395
create(Context & ctx,SMRange loc)396 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) {
397 return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
398 TypeConstraintDecl(loc);
399 }
400
401 //===----------------------------------------------------------------------===//
402 // TypeRangeConstraintDecl
403 //===----------------------------------------------------------------------===//
404
create(Context & ctx,SMRange loc)405 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
406 SMRange loc) {
407 return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
408 TypeRangeConstraintDecl(loc);
409 }
410
411 //===----------------------------------------------------------------------===//
412 // ValueConstraintDecl
413 //===----------------------------------------------------------------------===//
414
create(Context & ctx,SMRange loc,Expr * typeExpr)415 ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc,
416 Expr *typeExpr) {
417 return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
418 ValueConstraintDecl(loc, typeExpr);
419 }
420
421 //===----------------------------------------------------------------------===//
422 // ValueRangeConstraintDecl
423 //===----------------------------------------------------------------------===//
424
425 ValueRangeConstraintDecl *
create(Context & ctx,SMRange loc,Expr * typeExpr)426 ValueRangeConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
427 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>())
428 ValueRangeConstraintDecl(loc, typeExpr);
429 }
430
431 //===----------------------------------------------------------------------===//
432 // UserConstraintDecl
433 //===----------------------------------------------------------------------===//
434
435 Optional<StringRef>
getNativeInputType(unsigned index) const436 UserConstraintDecl::getNativeInputType(unsigned index) const {
437 return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index]
438 : Optional<StringRef>();
439 }
440
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<StringRef> nativeInputTypes,ArrayRef<VariableDecl * > results,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)441 UserConstraintDecl *UserConstraintDecl::createImpl(
442 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
443 ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results,
444 Optional<StringRef> codeBlock, const CompoundStmt *body, Type resultType) {
445 bool hasNativeInputTypes = !nativeInputTypes.empty();
446 assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size());
447
448 unsigned allocSize =
449 UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>(
450 inputs.size() + results.size(),
451 hasNativeInputTypes ? inputs.size() : 0);
452 void *rawData =
453 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl));
454 if (codeBlock)
455 codeBlock = codeBlock->copy(ctx.getAllocator());
456
457 UserConstraintDecl *decl = new (rawData)
458 UserConstraintDecl(name, inputs.size(), hasNativeInputTypes,
459 results.size(), codeBlock, body, resultType);
460 std::uninitialized_copy(inputs.begin(), inputs.end(),
461 decl->getInputs().begin());
462 std::uninitialized_copy(results.begin(), results.end(),
463 decl->getResults().begin());
464 if (hasNativeInputTypes) {
465 StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>();
466 for (unsigned i = 0, e = inputs.size(); i < e; ++i)
467 nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator());
468 }
469
470 return decl;
471 }
472
473 //===----------------------------------------------------------------------===//
474 // NamedAttributeDecl
475 //===----------------------------------------------------------------------===//
476
create(Context & ctx,const Name & name,Expr * value)477 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
478 Expr *value) {
479 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
480 NamedAttributeDecl(name, value);
481 }
482
483 //===----------------------------------------------------------------------===//
484 // OpNameDecl
485 //===----------------------------------------------------------------------===//
486
create(Context & ctx,const Name & name)487 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
488 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name);
489 }
create(Context & ctx,SMRange loc)490 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
491 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc);
492 }
493
494 //===----------------------------------------------------------------------===//
495 // PatternDecl
496 //===----------------------------------------------------------------------===//
497
create(Context & ctx,SMRange loc,const Name * name,Optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)498 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name,
499 Optional<uint16_t> benefit,
500 bool hasBoundedRecursion,
501 const CompoundStmt *body) {
502 return new (ctx.getAllocator().Allocate<PatternDecl>())
503 PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
504 }
505
506 //===----------------------------------------------------------------------===//
507 // UserRewriteDecl
508 //===----------------------------------------------------------------------===//
509
createImpl(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)510 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
511 ArrayRef<VariableDecl *> inputs,
512 ArrayRef<VariableDecl *> results,
513 Optional<StringRef> codeBlock,
514 const CompoundStmt *body,
515 Type resultType) {
516 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
517 inputs.size() + results.size());
518 void *rawData =
519 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl));
520 if (codeBlock)
521 codeBlock = codeBlock->copy(ctx.getAllocator());
522
523 UserRewriteDecl *decl = new (rawData) UserRewriteDecl(
524 name, inputs.size(), results.size(), codeBlock, body, resultType);
525 std::uninitialized_copy(inputs.begin(), inputs.end(),
526 decl->getInputs().begin());
527 std::uninitialized_copy(results.begin(), results.end(),
528 decl->getResults().begin());
529 return decl;
530 }
531
532 //===----------------------------------------------------------------------===//
533 // VariableDecl
534 //===----------------------------------------------------------------------===//
535
create(Context & ctx,const Name & name,Type type,Expr * initExpr,ArrayRef<ConstraintRef> constraints)536 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
537 Expr *initExpr,
538 ArrayRef<ConstraintRef> constraints) {
539 unsigned allocSize =
540 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size());
541 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl));
542
543 VariableDecl *varDecl =
544 new (rawData) VariableDecl(name, type, initExpr, constraints.size());
545 std::uninitialized_copy(constraints.begin(), constraints.end(),
546 varDecl->getConstraints().begin());
547 return varDecl;
548 }
549
550 //===----------------------------------------------------------------------===//
551 // Module
552 //===----------------------------------------------------------------------===//
553
create(Context & ctx,SMLoc loc,ArrayRef<Decl * > children)554 Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) {
555 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size());
556 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module));
557
558 Module *module = new (rawData) Module(loc, children.size());
559 std::uninitialized_copy(children.begin(), children.end(),
560 module->getChildren().begin());
561 return module;
562 }
563