//===- Nodes.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::pdll::ast;

/// Copy a string reference into the context with a null terminator.
///
/// Empty strings are returned unchanged, without allocating (so the result may
/// still refer to the caller's storage). Otherwise `str.size() + 1` bytes are
/// allocated from the context allocator; the returned StringRef covers only the
/// characters, not the trailing '\0'.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
  if (str.empty())
    return str;

  char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
  std::copy(str.begin(), str.end(), data);
  data[str.size()] = 0;
  return StringRef(data, str.size());
}

//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//

/// Create a Name in the context's allocator. The name text is copied with a
/// null terminator via copyStringWithNull.
const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
  return *new (ctx.getAllocator().Allocate<Name>())
      Name(copyStringWithNull(ctx, name), location);
}

//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//

namespace {
/// A pre-order walker over the PDLL AST: `visitFn` is invoked on a node before
/// any of its children are visited. Each distinct node is passed to `visitFn`
/// at most once; nodes reachable along several paths (for example a Decl that
/// is also reached through a DeclRefExpr) are deduplicated via
/// `alreadyVisited`.
class NodeVisitor {
public:
  explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
      : visitFn(visitFn) {}

  /// Visit `node` and, recursively, its children. Null nodes and nodes that
  /// have already been visited are ignored.
  void visit(const Node *node) {
    if (!node || !alreadyVisited.insert(node).second)
      return;

    visitFn(node);
    // Dispatch to the `visitImpl` overload for the node's concrete class.
    // Every class listed here needs a matching overload below; a node whose
    // class is not listed reaches `llvm_unreachable`.
    TypeSwitch<const Node *>(node)
        .Case<
            // Statements.
            const CompoundStmt, const EraseStmt, const LetStmt,
            const ReplaceStmt, const ReturnStmt, const RewriteStmt,

            // Expressions.
            const AttributeExpr, const CallExpr, const DeclRefExpr,
            const MemberAccessExpr, const OperationExpr, const TupleExpr,
            const TypeExpr,

            // Core Constraint Decls.
            const AttrConstraintDecl, const OpConstraintDecl,
            const TypeConstraintDecl, const TypeRangeConstraintDecl,
            const ValueConstraintDecl, const ValueRangeConstraintDecl,

            // Decls.
            const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
            const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,

            const Module>(
            [&](auto derivedNode) { this->visitImpl(derivedNode); })
        .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
  }

private:
  //===--------------------------------------------------------------------===//
  // Statements: visit every sub-statement / sub-expression.

  void visitImpl(const CompoundStmt *stmt) {
    for (const Node *child : stmt->getChildren())
      visit(child);
  }
  void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
  void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
  void visitImpl(const ReplaceStmt *stmt) {
    visit(stmt->getRootOpExpr());
    for (const Node *child : stmt->getReplExprs())
      visit(child);
  }
  void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
  void visitImpl(const RewriteStmt *stmt) {
    visit(stmt->getRootOpExpr());
    visit(stmt->getRewriteBody());
  }

  //===--------------------------------------------------------------------===//
  // Expressions.

  // Leaf node: nothing further to visit.
  void visitImpl(const AttributeExpr *expr) {}
  void visitImpl(const CallExpr *expr) {
    visit(expr->getCallableExpr());
    for (const Node *child : expr->getArguments())
      visit(child);
  }
  // Follows the reference into the referenced Decl itself.
  void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
  void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
  void visitImpl(const OperationExpr *expr) {
    visit(expr->getNameDecl());
    for (const Node *child : expr->getOperands())
      visit(child);
    for (const Node *child : expr->getResultTypes())
      visit(child);
    for (const Node *child : expr->getAttributes())
      visit(child);
  }
  void visitImpl(const TupleExpr *expr) {
    for (const Node *child : expr->getElements())
      visit(child);
  }
  // Leaf node: nothing further to visit.
  void visitImpl(const TypeExpr *expr) {}

  //===--------------------------------------------------------------------===//
  // Core constraint decls.

  void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
  void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
  // Leaf nodes: nothing further to visit.
  void visitImpl(const TypeConstraintDecl *decl) {}
  void visitImpl(const TypeRangeConstraintDecl *decl) {}
  void visitImpl(const ValueConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }
  void visitImpl(const ValueRangeConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }

  //===--------------------------------------------------------------------===//
  // Decls.

  void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
  // Leaf node: nothing further to visit.
  void visitImpl(const OpNameDecl *decl) {}
  void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
  void visitImpl(const UserConstraintDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  void visitImpl(const UserRewriteDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  void visitImpl(const VariableDecl *decl) {
    visit(decl->getInitExpr());
    for (const ConstraintRef &child : decl->getConstraints())
      visit(child.constraint);
  }

  void visitImpl(const Module *module) {
    for (const Node *child : module->getChildren())
      visit(child);
  }

  /// Callback invoked once for every distinct node that is visited.
  function_ref<void(const Node *)> visitFn;
  /// Nodes that have already been passed to `visitFn`.
  SmallPtrSet<const Node *, 16> alreadyVisited;
};
} // namespace

/// Walk this node and every node reachable from it, invoking `walkFn` once per
/// distinct node, in pre-order.
void Node::walk(function_ref<void(const Node *)> walkFn) const {
  return NodeVisitor(walkFn).visit(this);
}

//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//

/// Add `decl` to this scope, keyed by its name. The decl must be named, and no
/// decl with the same name may already exist in this scope. Parent scopes are
/// not checked here.
void DeclScope::add(Decl *decl) {
  const Name *name = decl->getName();
  assert(name && "expected a named decl");
  assert(!decls.count(name->getName()) && "decl with this name already exists");
  decls.try_emplace(name->getName(), decl);
}

/// Return the decl named `name` from this scope, or from the nearest parent
/// scope that defines it. Returns nullptr if no scope in the chain does.
Decl *DeclScope::lookup(StringRef name) {
  if (Decl *decl = decls.lookup(name))
    return decl;
  return parent ? parent->lookup(name) : nullptr;
}

//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//

/// Create a CompoundStmt whose child statements are stored as trailing objects
/// in the same context allocation as the node itself.
CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
                                   ArrayRef<Stmt *> children) {
  unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));

  CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
  std::uninitialized_copy(children.begin(), children.end(),
                          stmt->getChildren().begin());
  return stmt;
}

//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//

/// Create a LetStmt that wraps the given variable declaration.
LetStmt *LetStmt::create(Context &ctx, SMRange loc,
                         VariableDecl *varDecl) {
  return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
}

//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// EraseStmt

/// Create an EraseStmt whose root operation is given by `rootOp`.
EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
  return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
}

//===----------------------------------------------------------------------===//
// ReplaceStmt

/// Create a ReplaceStmt for `rootOp`. The replacement expressions are stored
/// as trailing objects in the same context allocation as the node.
ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 ArrayRef<Expr *> replExprs) {
  unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));

  ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
  std::uninitialized_copy(replExprs.begin(), replExprs.end(),
                          stmt->getReplExprs().begin());
  return stmt;
}

//===----------------------------------------------------------------------===//
// RewriteStmt

/// Create a RewriteStmt for `rootOp` with the given rewrite body.
RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 CompoundStmt *rewriteBody) {
  return new (ctx.getAllocator().Allocate<RewriteStmt>())
      RewriteStmt(loc, rootOp, rewriteBody);
}

//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//

/// Create a ReturnStmt whose returned expression is `resultExpr`.
ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
  return new (ctx.getAllocator().Allocate<ReturnStmt>())
      ReturnStmt(loc, resultExpr);
}

//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//

/// Create an AttributeExpr. The attribute's textual `value` is copied into the
/// context with a null terminator.
AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
                                     StringRef value) {
  return new (ctx.getAllocator().Allocate<AttributeExpr>())
      AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
}

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//

/// Create a CallExpr of `callable` with type `resultType`. The argument
/// expressions are stored as trailing objects in the same context allocation.
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
                           ArrayRef<Expr *> arguments, Type resultType) {
  unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));

  CallExpr *expr =
      new (rawData) CallExpr(loc, resultType, callable, arguments.size());
  std::uninitialized_copy(arguments.begin(), arguments.end(),
                          expr->getArguments().begin());
  return expr;
}

//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//

/// Create a DeclRefExpr of type `type` that refers to `decl`.
DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
                                 Type type) {
  return new (ctx.getAllocator().Allocate<DeclRefExpr>())
      DeclRefExpr(loc, decl, type);
}

//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//

/// Create a MemberAccessExpr of `memberName` on `parentExpr`. The member name
/// is copied into the context's allocator.
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
                                           const Expr *parentExpr,
                                           StringRef memberName, Type type) {
  return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
      loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
}

//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//

/// Create an OperationExpr.
///
/// Operands and result types together occupy the `Expr *` trailing storage,
/// and attributes occupy the `NamedAttributeDecl *` trailing storage, all in a
/// single context allocation. The expression's type is the OperationType
/// obtained from the name decl's name and `odsOp`. The name decl's location is
/// also forwarded to the constructor.
OperationExpr *
OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp,
                      const OpNameDecl *name, ArrayRef<Expr *> operands,
                      ArrayRef<Expr *> resultTypes,
                      ArrayRef<NamedAttributeDecl *> attributes) {
  unsigned allocSize =
      OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
          operands.size() + resultTypes.size(), attributes.size());
  void *rawData =
      ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));

  Type resultType = OperationType::get(ctx, name->getName(), odsOp);
  OperationExpr *opExpr = new (rawData)
      OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
                    attributes.size(), name->getLoc());
  std::uninitialized_copy(operands.begin(), operands.end(),
                          opExpr->getOperands().begin());
  std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
                          opExpr->getResultTypes().begin());
  std::uninitialized_copy(attributes.begin(), attributes.end(),
                          opExpr->getAttributes().begin());
  return opExpr;
}

/// Return the operation name held by this expression's name decl.
Optional<StringRef> OperationExpr::getName() const {
  return getNameDecl()->getName();
}

//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//

/// Create a TupleExpr. Its TupleType is built from the types of `elements` and
/// the given element `names`. The element expressions are stored as trailing
/// objects in the same context allocation.
TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
                             ArrayRef<Expr *> elements,
                             ArrayRef<StringRef> names) {
  unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));

  auto elementTypes = llvm::map_range(
      elements, [](const Expr *expr) { return expr->getType(); });
  TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);

  TupleExpr *expr = new (rawData) TupleExpr(loc, type);
  std::uninitialized_copy(elements.begin(), elements.end(),
                          expr->getElements().begin());
  return expr;
}

//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//

/// Create a TypeExpr. The type's textual `value` is copied into the context
/// with a null terminator.
TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
  return new (ctx.getAllocator().Allocate<TypeExpr>())
      TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
}

//===----------------------------------------------------------------------===//
// Decl
//===----------------------------------------------------------------------===//

/// Set the documentation comment of this decl. The text is copied into the
/// context's allocator rather than referencing the caller's storage.
void Decl::setDocComment(Context &ctx, StringRef comment) {
  docComment = comment.copy(ctx.getAllocator());
}

//===----------------------------------------------------------------------===//
// AttrConstraintDecl
//===----------------------------------------------------------------------===//

/// Create an AttrConstraintDecl with the given type expression `typeExpr`.
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
                                               Expr *typeExpr) {
  return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
      AttrConstraintDecl(loc, typeExpr);
}

//===----------------------------------------------------------------------===//
// OpConstraintDecl
//===----------------------------------------------------------------------===//

/// Create an OpConstraintDecl. If `nameDecl` is null, a placeholder OpNameDecl
/// is created from an empty SMRange, so the constraint always has a name decl.
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
                                           const OpNameDecl *nameDecl) {
  if (!nameDecl)
    nameDecl = OpNameDecl::create(ctx, SMRange());

  return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
      OpConstraintDecl(loc, nameDecl);
}

/// Return the operation name held by this constraint's name decl.
Optional<StringRef> OpConstraintDecl::getName() const {
  return getNameDecl()->getName();
}

//===----------------------------------------------------------------------===//
// TypeConstraintDecl
//===----------------------------------------------------------------------===//

/// Create a TypeConstraintDecl at `loc`.
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
                                               SMRange loc) {
  return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
      TypeConstraintDecl(loc);
}

//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
//===----------------------------------------------------------------------===//

/// Create a TypeRangeConstraintDecl at `loc`.
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
                                                         SMRange loc) {
  return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
      TypeRangeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===// 414 // ValueConstraintDecl 415 //===----------------------------------------------------------------------===// 416 417 ValueConstraintDecl * 418 ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) { 419 return new (ctx.getAllocator().Allocate<ValueConstraintDecl>()) 420 ValueConstraintDecl(loc, typeExpr); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // ValueRangeConstraintDecl 425 //===----------------------------------------------------------------------===// 426 427 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx, 428 SMRange loc, 429 Expr *typeExpr) { 430 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>()) 431 ValueRangeConstraintDecl(loc, typeExpr); 432 } 433 434 //===----------------------------------------------------------------------===// 435 // UserConstraintDecl 436 //===----------------------------------------------------------------------===// 437 438 Optional<StringRef> 439 UserConstraintDecl::getNativeInputType(unsigned index) const { 440 return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index] 441 : Optional<StringRef>(); 442 } 443 444 UserConstraintDecl *UserConstraintDecl::createImpl( 445 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs, 446 ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results, 447 Optional<StringRef> codeBlock, const CompoundStmt *body, Type resultType) { 448 bool hasNativeInputTypes = !nativeInputTypes.empty(); 449 assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size()); 450 451 unsigned allocSize = 452 UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>( 453 inputs.size() + results.size(), 454 hasNativeInputTypes ? 
inputs.size() : 0); 455 void *rawData = 456 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); 457 if (codeBlock) 458 codeBlock = codeBlock->copy(ctx.getAllocator()); 459 460 UserConstraintDecl *decl = new (rawData) 461 UserConstraintDecl(name, inputs.size(), hasNativeInputTypes, 462 results.size(), codeBlock, body, resultType); 463 std::uninitialized_copy(inputs.begin(), inputs.end(), 464 decl->getInputs().begin()); 465 std::uninitialized_copy(results.begin(), results.end(), 466 decl->getResults().begin()); 467 if (hasNativeInputTypes) { 468 StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>(); 469 for (unsigned i = 0, e = inputs.size(); i < e; ++i) 470 nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator()); 471 } 472 473 return decl; 474 } 475 476 //===----------------------------------------------------------------------===// 477 // NamedAttributeDecl 478 //===----------------------------------------------------------------------===// 479 480 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name, 481 Expr *value) { 482 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>()) 483 NamedAttributeDecl(name, value); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // OpNameDecl 488 //===----------------------------------------------------------------------===// 489 490 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) { 491 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name); 492 } 493 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) { 494 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc); 495 } 496 497 //===----------------------------------------------------------------------===// 498 // PatternDecl 499 //===----------------------------------------------------------------------===// 500 501 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, 502 const Name *name, 
Optional<uint16_t> benefit, 503 bool hasBoundedRecursion, 504 const CompoundStmt *body) { 505 return new (ctx.getAllocator().Allocate<PatternDecl>()) 506 PatternDecl(loc, name, benefit, hasBoundedRecursion, body); 507 } 508 509 //===----------------------------------------------------------------------===// 510 // UserRewriteDecl 511 //===----------------------------------------------------------------------===// 512 513 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name, 514 ArrayRef<VariableDecl *> inputs, 515 ArrayRef<VariableDecl *> results, 516 Optional<StringRef> codeBlock, 517 const CompoundStmt *body, 518 Type resultType) { 519 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>( 520 inputs.size() + results.size()); 521 void *rawData = 522 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl)); 523 if (codeBlock) 524 codeBlock = codeBlock->copy(ctx.getAllocator()); 525 526 UserRewriteDecl *decl = new (rawData) UserRewriteDecl( 527 name, inputs.size(), results.size(), codeBlock, body, resultType); 528 std::uninitialized_copy(inputs.begin(), inputs.end(), 529 decl->getInputs().begin()); 530 std::uninitialized_copy(results.begin(), results.end(), 531 decl->getResults().begin()); 532 return decl; 533 } 534 535 //===----------------------------------------------------------------------===// 536 // VariableDecl 537 //===----------------------------------------------------------------------===// 538 539 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type, 540 Expr *initExpr, 541 ArrayRef<ConstraintRef> constraints) { 542 unsigned allocSize = 543 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size()); 544 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl)); 545 546 VariableDecl *varDecl = 547 new (rawData) VariableDecl(name, type, initExpr, constraints.size()); 548 std::uninitialized_copy(constraints.begin(), constraints.end(), 549 
varDecl->getConstraints().begin()); 550 return varDecl; 551 } 552 553 //===----------------------------------------------------------------------===// 554 // Module 555 //===----------------------------------------------------------------------===// 556 557 Module *Module::create(Context &ctx, SMLoc loc, 558 ArrayRef<Decl *> children) { 559 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size()); 560 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module)); 561 562 Module *module = new (rawData) Module(loc, children.size()); 563 std::uninitialized_copy(children.begin(), children.end(), 564 module->getChildren().begin()); 565 return module; 566 } 567