1 //===- Nodes.cpp ----------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/AST/Nodes.h" 10 #include "mlir/Tools/PDLL/AST/Context.h" 11 #include "llvm/ADT/SmallPtrSet.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::pdll::ast; 16 17 /// Copy a string reference into the context with a null terminator. 18 static StringRef copyStringWithNull(Context &ctx, StringRef str) { 19 if (str.empty()) 20 return str; 21 22 char *data = ctx.getAllocator().Allocate<char>(str.size() + 1); 23 std::copy(str.begin(), str.end(), data); 24 data[str.size()] = 0; 25 return StringRef(data, str.size()); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // Name 30 //===----------------------------------------------------------------------===// 31 32 const Name &Name::create(Context &ctx, StringRef name, SMRange location) { 33 return *new (ctx.getAllocator().Allocate<Name>()) 34 Name(copyStringWithNull(ctx, name), location); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // Node 39 //===----------------------------------------------------------------------===// 40 41 namespace { 42 class NodeVisitor { 43 public: 44 explicit NodeVisitor(function_ref<void(const Node *)> visitFn) 45 : visitFn(visitFn) {} 46 47 void visit(const Node *node) { 48 if (!node || !alreadyVisited.insert(node).second) 49 return; 50 51 visitFn(node); 52 TypeSwitch<const Node *>(node) 53 .Case< 54 // Statements. 55 const CompoundStmt, const EraseStmt, const LetStmt, 56 const ReplaceStmt, const ReturnStmt, const RewriteStmt, 57 58 // Expressions. 59 const AttributeExpr, const CallExpr, const DeclRefExpr, 60 const MemberAccessExpr, const OperationExpr, const TupleExpr, 61 const TypeExpr, 62 63 // Core Constraint Decls. 64 const AttrConstraintDecl, const OpConstraintDecl, 65 const TypeConstraintDecl, const TypeRangeConstraintDecl, 66 const ValueConstraintDecl, const ValueRangeConstraintDecl, 67 68 // Decls. 69 const NamedAttributeDecl, const OpNameDecl, const PatternDecl, 70 const UserConstraintDecl, const UserRewriteDecl, const VariableDecl, 71 72 const Module>( 73 [&](auto derivedNode) { this->visitImpl(derivedNode); }) 74 .Default([](const Node *) { llvm_unreachable("unknown AST node"); }); 75 } 76 77 private: 78 void visitImpl(const CompoundStmt *stmt) { 79 for (const Node *child : stmt->getChildren()) 80 visit(child); 81 } 82 void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); } 83 void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); } 84 void visitImpl(const ReplaceStmt *stmt) { 85 visit(stmt->getRootOpExpr()); 86 for (const Node *child : stmt->getReplExprs()) 87 visit(child); 88 } 89 void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); } 90 void visitImpl(const RewriteStmt *stmt) { 91 visit(stmt->getRootOpExpr()); 92 visit(stmt->getRewriteBody()); 93 } 94 95 void visitImpl(const AttributeExpr *expr) {} 96 void visitImpl(const CallExpr *expr) { 97 visit(expr->getCallableExpr()); 98 for (const Node *child : expr->getArguments()) 99 visit(child); 100 } 101 void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); } 102 void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); } 103 void visitImpl(const OperationExpr *expr) { 104 visit(expr->getNameDecl()); 105 for (const Node *child : expr->getOperands()) 106 visit(child); 107 for (const Node *child : expr->getResultTypes()) 108 visit(child); 109 for (const Node *child : expr->getAttributes()) 110 visit(child); 111 } 112 void visitImpl(const TupleExpr *expr) { 113 for (const Node *child : expr->getElements()) 114 visit(child); 115 } 116 void visitImpl(const TypeExpr *expr) {} 117 118 void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); } 119 void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); } 120 void visitImpl(const TypeConstraintDecl *decl) {} 121 void visitImpl(const TypeRangeConstraintDecl *decl) {} 122 void visitImpl(const ValueConstraintDecl *decl) { 123 visit(decl->getTypeExpr()); 124 } 125 void visitImpl(const ValueRangeConstraintDecl *decl) { 126 visit(decl->getTypeExpr()); 127 } 128 129 void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); } 130 void visitImpl(const OpNameDecl *decl) {} 131 void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); } 132 void visitImpl(const UserConstraintDecl *decl) { 133 for (const Node *child : decl->getInputs()) 134 visit(child); 135 for (const Node *child : decl->getResults()) 136 visit(child); 137 visit(decl->getBody()); 138 } 139 void visitImpl(const UserRewriteDecl *decl) { 140 for (const Node *child : decl->getInputs()) 141 visit(child); 142 for (const Node *child : decl->getResults()) 143 visit(child); 144 visit(decl->getBody()); 145 } 146 void visitImpl(const VariableDecl *decl) { 147 visit(decl->getInitExpr()); 148 for (const ConstraintRef &child : decl->getConstraints()) 149 visit(child.constraint); 150 } 151 152 void visitImpl(const Module *module) { 153 for (const Node *child : module->getChildren()) 154 visit(child); 155 } 156 157 function_ref<void(const Node *)> visitFn; 158 SmallPtrSet<const Node *, 16> alreadyVisited; 159 }; 160 } // namespace 161 162 void Node::walk(function_ref<void(const Node *)> walkFn) const { 163 return NodeVisitor(walkFn).visit(this); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // DeclScope 168 //===----------------------------------------------------------------------===// 169 170 void DeclScope::add(Decl *decl) { 171 const Name *name = decl->getName(); 172 assert(name && "expected a named decl"); 173 assert(!decls.count(name->getName()) && "decl with this name already exists"); 174 decls.try_emplace(name->getName(), decl); 175 } 176 177 Decl *DeclScope::lookup(StringRef name) { 178 if (Decl *decl = decls.lookup(name)) 179 return decl; 180 return parent ? parent->lookup(name) : nullptr; 181 } 182 183 //===----------------------------------------------------------------------===// 184 // CompoundStmt 185 //===----------------------------------------------------------------------===// 186 187 CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc, 188 ArrayRef<Stmt *> children) { 189 unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size()); 190 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt)); 191 192 CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size()); 193 std::uninitialized_copy(children.begin(), children.end(), 194 stmt->getChildren().begin()); 195 return stmt; 196 } 197 198 //===----------------------------------------------------------------------===// 199 // LetStmt 200 //===----------------------------------------------------------------------===// 201 202 LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) { 203 return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // OpRewriteStmt 208 //===----------------------------------------------------------------------===// 209 210 //===----------------------------------------------------------------------===// 211 // EraseStmt 212 213 EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) { 214 return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp); 215 } 216 217 //===----------------------------------------------------------------------===// 218 // ReplaceStmt 219 220 ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp, 221 ArrayRef<Expr *> replExprs) { 222 unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size()); 223 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt)); 224 225 ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size()); 226 std::uninitialized_copy(replExprs.begin(), replExprs.end(), 227 stmt->getReplExprs().begin()); 228 return stmt; 229 } 230 231 //===----------------------------------------------------------------------===// 232 // RewriteStmt 233 234 RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp, 235 CompoundStmt *rewriteBody) { 236 return new (ctx.getAllocator().Allocate<RewriteStmt>()) 237 RewriteStmt(loc, rootOp, rewriteBody); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // ReturnStmt 242 //===----------------------------------------------------------------------===// 243 244 ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) { 245 return new (ctx.getAllocator().Allocate<ReturnStmt>()) 246 ReturnStmt(loc, resultExpr); 247 } 248 249 //===----------------------------------------------------------------------===// 250 // AttributeExpr 251 //===----------------------------------------------------------------------===// 252 253 AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc, 254 StringRef value) { 255 return new (ctx.getAllocator().Allocate<AttributeExpr>()) 256 AttributeExpr(ctx, loc, copyStringWithNull(ctx, value)); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // CallExpr 261 //===----------------------------------------------------------------------===// 262 263 CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable, 264 ArrayRef<Expr *> arguments, Type resultType) { 265 unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size()); 266 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr)); 267 268 CallExpr *expr = 269 new (rawData) CallExpr(loc, resultType, callable, arguments.size()); 270 std::uninitialized_copy(arguments.begin(), arguments.end(), 271 expr->getArguments().begin()); 272 return expr; 273 } 274 275 //===----------------------------------------------------------------------===// 276 // DeclRefExpr 277 //===----------------------------------------------------------------------===// 278 279 DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl, 280 Type type) { 281 return new (ctx.getAllocator().Allocate<DeclRefExpr>()) 282 DeclRefExpr(loc, decl, type); 283 } 284 285 //===----------------------------------------------------------------------===// 286 // MemberAccessExpr 287 //===----------------------------------------------------------------------===// 288 289 MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc, 290 const Expr *parentExpr, 291 StringRef memberName, Type type) { 292 return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr( 293 loc, parentExpr, memberName.copy(ctx.getAllocator()), type); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // OperationExpr 298 //===----------------------------------------------------------------------===// 299 300 OperationExpr * 301 OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp, 302 const OpNameDecl *name, ArrayRef<Expr *> operands, 303 ArrayRef<Expr *> resultTypes, 304 ArrayRef<NamedAttributeDecl *> attributes) { 305 unsigned allocSize = 306 OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>( 307 operands.size() + resultTypes.size(), attributes.size()); 308 void *rawData = 309 ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr)); 310 311 Type resultType = OperationType::get(ctx, name->getName(), odsOp); 312 OperationExpr *opExpr = new (rawData) 313 OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(), 314 attributes.size(), name->getLoc()); 315 std::uninitialized_copy(operands.begin(), operands.end(), 316 opExpr->getOperands().begin()); 317 std::uninitialized_copy(resultTypes.begin(), resultTypes.end(), 318 opExpr->getResultTypes().begin()); 319 std::uninitialized_copy(attributes.begin(), attributes.end(), 320 opExpr->getAttributes().begin()); 321 return opExpr; 322 } 323 324 Optional<StringRef> OperationExpr::getName() const { 325 return getNameDecl()->getName(); 326 } 327 328 //===----------------------------------------------------------------------===// 329 // TupleExpr 330 //===----------------------------------------------------------------------===// 331 332 TupleExpr *TupleExpr::create(Context &ctx, SMRange loc, 333 ArrayRef<Expr *> elements, 334 ArrayRef<StringRef> names) { 335 unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size()); 336 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr)); 337 338 auto elementTypes = llvm::map_range( 339 elements, [](const Expr *expr) { return expr->getType(); }); 340 TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names); 341 342 TupleExpr *expr = new (rawData) TupleExpr(loc, type); 343 std::uninitialized_copy(elements.begin(), elements.end(), 344 expr->getElements().begin()); 345 return expr; 346 } 347 348 //===----------------------------------------------------------------------===// 349 // TypeExpr 350 //===----------------------------------------------------------------------===// 351 352 TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) { 353 return new (ctx.getAllocator().Allocate<TypeExpr>()) 354 TypeExpr(ctx, loc, copyStringWithNull(ctx, value)); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // Decl 359 //===----------------------------------------------------------------------===// 360 361 void Decl::setDocComment(Context &ctx, StringRef comment) { 362 docComment = comment.copy(ctx.getAllocator()); 363 } 364 365 //===----------------------------------------------------------------------===// 366 // AttrConstraintDecl 367 //===----------------------------------------------------------------------===// 368 369 AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc, 370 Expr *typeExpr) { 371 return new (ctx.getAllocator().Allocate<AttrConstraintDecl>()) 372 AttrConstraintDecl(loc, typeExpr); 373 } 374 375 //===----------------------------------------------------------------------===// 376 // OpConstraintDecl 377 //===----------------------------------------------------------------------===// 378 379 OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc, 380 const OpNameDecl *nameDecl) { 381 if (!nameDecl) 382 nameDecl = OpNameDecl::create(ctx, SMRange()); 383 384 return new (ctx.getAllocator().Allocate<OpConstraintDecl>()) 385 OpConstraintDecl(loc, nameDecl); 386 } 387 388 Optional<StringRef> OpConstraintDecl::getName() const { 389 return getNameDecl()->getName(); 390 } 391 392 //===----------------------------------------------------------------------===// 393 // TypeConstraintDecl 394 //===----------------------------------------------------------------------===// 395 396 TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) { 397 return new (ctx.getAllocator().Allocate<TypeConstraintDecl>()) 398 TypeConstraintDecl(loc); 399 } 400 401 //===----------------------------------------------------------------------===// 402 // TypeRangeConstraintDecl 403 //===----------------------------------------------------------------------===// 404 405 TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx, 406 SMRange loc) { 407 return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>()) 408 TypeRangeConstraintDecl(loc); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // ValueConstraintDecl 413 //===----------------------------------------------------------------------===// 414 415 ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc, 416 Expr *typeExpr) { 417 return new (ctx.getAllocator().Allocate<ValueConstraintDecl>()) 418 ValueConstraintDecl(loc, typeExpr); 419 } 420 421 //===----------------------------------------------------------------------===// 422 // ValueRangeConstraintDecl 423 //===----------------------------------------------------------------------===// 424 425 ValueRangeConstraintDecl * 426 ValueRangeConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) { 427 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>()) 428 ValueRangeConstraintDecl(loc, typeExpr); 429 } 430 431 //===----------------------------------------------------------------------===// 432 // UserConstraintDecl 433 //===----------------------------------------------------------------------===// 434 435 Optional<StringRef> 436 UserConstraintDecl::getNativeInputType(unsigned index) const { 437 return hasNativeInputTypes ? getTrailingObjects<StringRef>()[index] 438 : Optional<StringRef>(); 439 } 440 441 UserConstraintDecl *UserConstraintDecl::createImpl( 442 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs, 443 ArrayRef<StringRef> nativeInputTypes, ArrayRef<VariableDecl *> results, 444 Optional<StringRef> codeBlock, const CompoundStmt *body, Type resultType) { 445 bool hasNativeInputTypes = !nativeInputTypes.empty(); 446 assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size()); 447 448 unsigned allocSize = 449 UserConstraintDecl::totalSizeToAlloc<VariableDecl *, StringRef>( 450 inputs.size() + results.size(), 451 hasNativeInputTypes ? inputs.size() : 0); 452 void *rawData = 453 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); 454 if (codeBlock) 455 codeBlock = codeBlock->copy(ctx.getAllocator()); 456 457 UserConstraintDecl *decl = new (rawData) 458 UserConstraintDecl(name, inputs.size(), hasNativeInputTypes, 459 results.size(), codeBlock, body, resultType); 460 std::uninitialized_copy(inputs.begin(), inputs.end(), 461 decl->getInputs().begin()); 462 std::uninitialized_copy(results.begin(), results.end(), 463 decl->getResults().begin()); 464 if (hasNativeInputTypes) { 465 StringRef *nativeInputTypesPtr = decl->getTrailingObjects<StringRef>(); 466 for (unsigned i = 0, e = inputs.size(); i < e; ++i) 467 nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator()); 468 } 469 470 return decl; 471 } 472 473 //===----------------------------------------------------------------------===// 474 // NamedAttributeDecl 475 //===----------------------------------------------------------------------===// 476 477 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name, 478 Expr *value) { 479 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>()) 480 NamedAttributeDecl(name, value); 481 } 482 483 //===----------------------------------------------------------------------===// 484 // OpNameDecl 485 //===----------------------------------------------------------------------===// 486 487 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) { 488 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name); 489 } 490 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) { 491 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc); 492 } 493 494 //===----------------------------------------------------------------------===// 495 // PatternDecl 496 //===----------------------------------------------------------------------===// 497 498 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name, 499 Optional<uint16_t> benefit, 500 bool hasBoundedRecursion, 501 const CompoundStmt *body) { 502 return new (ctx.getAllocator().Allocate<PatternDecl>()) 503 PatternDecl(loc, name, benefit, hasBoundedRecursion, body); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // UserRewriteDecl 508 //===----------------------------------------------------------------------===// 509 510 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name, 511 ArrayRef<VariableDecl *> inputs, 512 ArrayRef<VariableDecl *> results, 513 Optional<StringRef> codeBlock, 514 const CompoundStmt *body, 515 Type resultType) { 516 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>( 517 inputs.size() + results.size()); 518 void *rawData = 519 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl)); 520 if (codeBlock) 521 codeBlock = codeBlock->copy(ctx.getAllocator()); 522 523 UserRewriteDecl *decl = new (rawData) UserRewriteDecl( 524 name, inputs.size(), results.size(), codeBlock, body, resultType); 525 std::uninitialized_copy(inputs.begin(), inputs.end(), 526 decl->getInputs().begin()); 527 std::uninitialized_copy(results.begin(), results.end(), 528 decl->getResults().begin()); 529 return decl; 530 } 531 532 //===----------------------------------------------------------------------===// 533 // VariableDecl 534 //===----------------------------------------------------------------------===// 535 536 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type, 537 Expr *initExpr, 538 ArrayRef<ConstraintRef> constraints) { 539 unsigned allocSize = 540 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size()); 541 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl)); 542 543 VariableDecl *varDecl = 544 new (rawData) VariableDecl(name, type, initExpr, constraints.size()); 545 std::uninitialized_copy(constraints.begin(), constraints.end(), 546 varDecl->getConstraints().begin()); 547 return varDecl; 548 } 549 550 //===----------------------------------------------------------------------===// 551 // Module 552 //===----------------------------------------------------------------------===// 553 554 Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) { 555 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size()); 556 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module)); 557 558 Module *module = new (rawData) Module(loc, children.size()); 559 std::uninitialized_copy(children.begin(), children.end(), 560 module->getChildren().begin()); 561 return module; 562 } 563