//===- Nodes.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::pdll::ast;

/// Copy a string reference into the context with a null terminator.
///
/// The returned StringRef points at storage owned by the context's allocator;
/// its size excludes the trailing '\0'. An empty input is returned unchanged
/// without allocating, so no terminator is guaranteed in that case.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
  if (str.empty())
    return str;

  char *data = ctx.getAllocator().Allocate<char>(str.size() + 1);
  std::copy(str.begin(), str.end(), data);
  data[str.size()] = 0;
  return StringRef(data, str.size());
}

//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//

/// Allocate a Name in the context. The name text is copied (null-terminated)
/// into the context, so the result does not reference the caller's buffer.
const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
  return *new (ctx.getAllocator().Allocate<Name>())
      Name(copyStringWithNull(ctx, name), location);
}

//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//

namespace {
/// Walks the AST reachable from a starting node and invokes a callback on each
/// node. Every node is visited at most once: nodes reachable along several
/// paths (e.g. a decl that is also the target of a DeclRefExpr) are recorded
/// in `alreadyVisited` and skipped when they are reached again.
class NodeVisitor {
public:
  explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
      : visitFn(visitFn) {}

  /// Visit `node`. Null and already-visited nodes are ignored. Otherwise the
  /// callback runs on the node first, and then its children are visited
  /// through the matching `visitImpl` overload (pre-order).
  void visit(const Node *node) {
    if (!node || !alreadyVisited.insert(node).second)
      return;

    visitFn(node);
    // Dispatch on the dynamic node class. A node kind missing from this list
    // falls through to the unreachable default, so new node classes must be
    // added here.
    TypeSwitch<const Node *>(node)
        .Case<
            // Statements.
            const CompoundStmt, const EraseStmt, const LetStmt,
            const ReplaceStmt, const ReturnStmt, const RewriteStmt,

            // Expressions.
            const AttributeExpr, const CallExpr, const DeclRefExpr,
            const MemberAccessExpr, const OperationExpr, const TupleExpr,
            const TypeExpr,

            // Core Constraint Decls.
            const AttrConstraintDecl, const OpConstraintDecl,
            const TypeConstraintDecl, const TypeRangeConstraintDecl,
            const ValueConstraintDecl, const ValueRangeConstraintDecl,

            // Decls.
            const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
            const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,

            const Module>(
            [&](auto derivedNode) { this->visitImpl(derivedNode); })
        .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
  }

private:
  // Per-kind child traversal. Each overload forwards the node's children to
  // visit(), which ignores null children, so no null checks are needed here.
  // Kinds with no child nodes have empty bodies.

  // Statements.
  void visitImpl(const CompoundStmt *stmt) {
    for (const Node *child : stmt->getChildren())
      visit(child);
  }
  void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
  void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
  void visitImpl(const ReplaceStmt *stmt) {
    visit(stmt->getRootOpExpr());
    for (const Node *child : stmt->getReplExprs())
      visit(child);
  }
  void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
  void visitImpl(const RewriteStmt *stmt) {
    visit(stmt->getRootOpExpr());
    visit(stmt->getRewriteBody());
  }

  // Expressions.
  void visitImpl(const AttributeExpr *expr) {}
  void visitImpl(const CallExpr *expr) {
    visit(expr->getCallableExpr());
    for (const Node *child : expr->getArguments())
      visit(child);
  }
  void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
  void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
  void visitImpl(const OperationExpr *expr) {
    visit(expr->getNameDecl());
    for (const Node *child : expr->getOperands())
      visit(child);
    for (const Node *child : expr->getResultTypes())
      visit(child);
    for (const Node *child : expr->getAttributes())
      visit(child);
  }
  void visitImpl(const TupleExpr *expr) {
    for (const Node *child : expr->getElements())
      visit(child);
  }
  void visitImpl(const TypeExpr *expr) {}

  // Core constraint decls.
  void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
  void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
  void visitImpl(const TypeConstraintDecl *decl) {}
  void visitImpl(const TypeRangeConstraintDecl *decl) {}
  void visitImpl(const ValueConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }
  void visitImpl(const ValueRangeConstraintDecl *decl) {
    visit(decl->getTypeExpr());
  }

  // Decls.
  void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
  void visitImpl(const OpNameDecl *decl) {}
  void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
  void visitImpl(const UserConstraintDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  void visitImpl(const UserRewriteDecl *decl) {
    for (const Node *child : decl->getInputs())
      visit(child);
    for (const Node *child : decl->getResults())
      visit(child);
    visit(decl->getBody());
  }
  void visitImpl(const VariableDecl *decl) {
    visit(decl->getInitExpr());
    // Constraints are stored as ConstraintRef wrappers; visit the referenced
    // constraint node itself.
    for (const ConstraintRef &child : decl->getConstraints())
      visit(child.constraint);
  }

  void visitImpl(const Module *module) {
    for (const Node *child : module->getChildren())
      visit(child);
  }

  /// Callback invoked once for every visited node.
  function_ref<void(const Node *)> visitFn;
  /// Nodes already passed to `visitFn`; used to skip repeat visits.
  SmallPtrSet<const Node *, 16> alreadyVisited;
};
} // namespace

/// Walk this node and everything reachable from it, calling `walkFn` once per
/// node in pre-order (a node before its children).
void Node::walk(function_ref<void(const Node *)> walkFn) const {
  return NodeVisitor(walkFn).visit(this);
}

//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//

/// Register `decl` in this scope under its name.
///
/// The decl must be named, and the name must not already exist in *this*
/// scope. Parent scopes are not checked, so an inner decl may shadow an outer
/// one. Both requirements are enforced only by assertions. With assertions
/// disabled, a duplicate name keeps the existing entry, because try_emplace
/// does not overwrite.
void DeclScope::add(Decl *decl) {
  const Name *name = decl->getName();
  assert(name && "expected a named decl");
  assert(!decls.count(name->getName()) && "decl with this name already exists");
  decls.try_emplace(name->getName(), decl);
}

/// Look up `name` in this scope first, then in each enclosing parent scope in
/// turn. Returns null when no scope in the chain defines it.
Decl *DeclScope::lookup(StringRef name) {
  if (Decl *decl = decls.lookup(name))
    return decl;
  return parent ? parent->lookup(name) : nullptr;
}

//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//

/// Create a CompoundStmt. The node and its trailing `Stmt *` child array come
/// from a single context allocation sized for `children.size()` entries.
CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
                                   ArrayRef<Stmt *> children) {
  unsigned allocSize = CompoundStmt::totalSizeToAlloc<Stmt *>(children.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CompoundStmt));

  CompoundStmt *stmt = new (rawData) CompoundStmt(loc, children.size());
  // The trailing storage is raw memory, so construct the elements in place.
  std::uninitialized_copy(children.begin(), children.end(),
                          stmt->getChildren().begin());
  return stmt;
}

//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//

/// Create a LetStmt that declares `varDecl`.
LetStmt *LetStmt::create(Context &ctx, SMRange loc,
                         VariableDecl *varDecl) {
  return new (ctx.getAllocator().Allocate<LetStmt>()) LetStmt(loc, varDecl);
}

//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//

// The statements below (EraseStmt, ReplaceStmt, RewriteStmt) are each created
// with a root operation expression, `rootOp`.

//===----------------------------------------------------------------------===//
// EraseStmt

/// Create an EraseStmt for the operation expression `rootOp`.
EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
  return new (ctx.getAllocator().Allocate<EraseStmt>()) EraseStmt(loc, rootOp);
}

//===----------------------------------------------------------------------===//
// ReplaceStmt

/// Create a ReplaceStmt for `rootOp`. The replacement expressions are stored
/// in a trailing `Expr *` array allocated together with the node.
ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 ArrayRef<Expr *> replExprs) {
  unsigned allocSize = ReplaceStmt::totalSizeToAlloc<Expr *>(replExprs.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(ReplaceStmt));

  ReplaceStmt *stmt = new (rawData) ReplaceStmt(loc, rootOp, replExprs.size());
  std::uninitialized_copy(replExprs.begin(), replExprs.end(),
                          stmt->getReplExprs().begin());
  return stmt;
}

//===----------------------------------------------------------------------===//
// RewriteStmt

/// Create a RewriteStmt that applies `rewriteBody` to `rootOp`.
RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 CompoundStmt *rewriteBody) {
  return new (ctx.getAllocator().Allocate<RewriteStmt>())
      RewriteStmt(loc, rootOp, rewriteBody);
}

//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//

/// Create a ReturnStmt that yields `resultExpr`.
ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
  return new (ctx.getAllocator().Allocate<ReturnStmt>())
      ReturnStmt(loc, resultExpr);
}

//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//

/// Create an AttributeExpr. The attribute text `value` is copied into the
/// context with a null terminator.
AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
                                     StringRef value) {
  return new (ctx.getAllocator().Allocate<AttributeExpr>())
      AttributeExpr(ctx, loc, copyStringWithNull(ctx, value));
}

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//

/// Create a CallExpr of `callable` with type `resultType`. The arguments are
/// stored in a trailing `Expr *` array allocated together with the node.
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
                           ArrayRef<Expr *> arguments, Type resultType) {
  unsigned allocSize = CallExpr::totalSizeToAlloc<Expr *>(arguments.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(CallExpr));

  CallExpr *expr =
      new (rawData) CallExpr(loc, resultType, callable, arguments.size());
  std::uninitialized_copy(arguments.begin(), arguments.end(),
                          expr->getArguments().begin());
  return expr;
}

//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//

/// Create a DeclRefExpr that refers to `decl` and has type `type`.
DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
                                 Type type) {
  return new (ctx.getAllocator().Allocate<DeclRefExpr>())
      DeclRefExpr(loc, decl, type);
}

//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//

/// Create a MemberAccessExpr of `memberName` on `parentExpr`. The member name
/// is copied into the context with StringRef::copy, which, unlike
/// copyStringWithNull, does not add a null terminator.
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
                                           const Expr *parentExpr,
                                           StringRef memberName, Type type) {
  return new (ctx.getAllocator().Allocate<MemberAccessExpr>()) MemberAccessExpr(
      loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
}

//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//

/// Create an OperationExpr.
///
/// The node has two trailing arrays in one allocation:
/// - an `Expr *` array sized for operands plus result types;
/// - a `NamedAttributeDecl *` array for the attributes.
///
/// The expression's type is an OperationType built from the name decl's
/// (optional) name. The name decl's location is also passed to the node.
OperationExpr *OperationExpr::create(
    Context &ctx, SMRange loc, const OpNameDecl *name,
    ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
    ArrayRef<NamedAttributeDecl *> attributes) {
  unsigned allocSize =
      OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
          operands.size() + resultTypes.size(), attributes.size());
  void *rawData =
      ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));

  Type resultType = OperationType::get(ctx, name->getName());
  OperationExpr *opExpr = new (rawData)
      OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
                    attributes.size(), name->getLoc());
  std::uninitialized_copy(operands.begin(), operands.end(),
                          opExpr->getOperands().begin());
  std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
                          opExpr->getResultTypes().begin());
  std::uninitialized_copy(attributes.begin(), attributes.end(),
                          opExpr->getAttributes().begin());
  return opExpr;
}

/// Return the operation name held by this expression's name decl, if any.
Optional<StringRef> OperationExpr::getName() const {
  return getNameDecl()->getName();
}

//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//

/// Create a TupleExpr.
///
/// The tuple's TupleType is built from the element expressions' types together
/// with `names`. The elements are stored in a trailing `Expr *` array.
TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
                             ArrayRef<Expr *> elements,
                             ArrayRef<StringRef> names) {
  unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));

  auto elementTypes = llvm::map_range(
      elements, [](const Expr *expr) { return expr->getType(); });
  TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);

  TupleExpr *expr = new (rawData) TupleExpr(loc, type);
  std::uninitialized_copy(elements.begin(), elements.end(),
                          expr->getElements().begin());
  return expr;
}

//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//

/// Create a TypeExpr. The type text `value` is copied into the context with a
/// null terminator.
TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
  return new (ctx.getAllocator().Allocate<TypeExpr>())
      TypeExpr(ctx, loc, copyStringWithNull(ctx, value));
}

//===----------------------------------------------------------------------===//
// AttrConstraintDecl
//===----------------------------------------------------------------------===//

/// Create an AttrConstraintDecl constrained by the type expression `typeExpr`.
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
                                               Expr *typeExpr) {
  return new (ctx.getAllocator().Allocate<AttrConstraintDecl>())
      AttrConstraintDecl(loc, typeExpr);
}

//===----------------------------------------------------------------------===//
// OpConstraintDecl
//===----------------------------------------------------------------------===//

/// Create an OpConstraintDecl. If `nameDecl` is null, a location-only
/// OpNameDecl (built from an empty SMRange) is substituted. This keeps the
/// stored name decl non-null, which getName() below dereferences
/// unconditionally.
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
                                           const OpNameDecl *nameDecl) {
  if (!nameDecl)
    nameDecl = OpNameDecl::create(ctx, SMRange());

  return new (ctx.getAllocator().Allocate<OpConstraintDecl>())
      OpConstraintDecl(loc, nameDecl);
}

/// Return the operation name held by this constraint's name decl, if any.
Optional<StringRef> OpConstraintDecl::getName() const {
  return getNameDecl()->getName();
}

//===----------------------------------------------------------------------===//
// TypeConstraintDecl
//===----------------------------------------------------------------------===//

/// Create a TypeConstraintDecl at `loc`.
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx,
                                               SMRange loc) {
  return new (ctx.getAllocator().Allocate<TypeConstraintDecl>())
      TypeConstraintDecl(loc);
}

//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
//===----------------------------------------------------------------------===//

/// Create a TypeRangeConstraintDecl at `loc`.
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
                                                         SMRange loc) {
  return new (ctx.getAllocator().Allocate<TypeRangeConstraintDecl>())
      TypeRangeConstraintDecl(loc);
}

//===----------------------------------------------------------------------===//
// ValueConstraintDecl
//===----------------------------------------------------------------------===//

/// Create a ValueConstraintDecl constrained by the type expression `typeExpr`.
ValueConstraintDecl *
ValueConstraintDecl::create(Context &ctx, SMRange loc, Expr *typeExpr) {
  return new (ctx.getAllocator().Allocate<ValueConstraintDecl>())
      ValueConstraintDecl(loc, typeExpr);
}

//===----------------------------------------------------------------------===// 415 // ValueRangeConstraintDecl 416 //===----------------------------------------------------------------------===// 417 418 ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx, 419 SMRange loc, 420 Expr *typeExpr) { 421 return new (ctx.getAllocator().Allocate<ValueRangeConstraintDecl>()) 422 ValueRangeConstraintDecl(loc, typeExpr); 423 } 424 425 //===----------------------------------------------------------------------===// 426 // UserConstraintDecl 427 //===----------------------------------------------------------------------===// 428 429 UserConstraintDecl *UserConstraintDecl::createImpl( 430 Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs, 431 ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock, 432 const CompoundStmt *body, Type resultType) { 433 unsigned allocSize = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>( 434 inputs.size() + results.size()); 435 void *rawData = 436 ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); 437 if (codeBlock) 438 codeBlock = codeBlock->copy(ctx.getAllocator()); 439 440 UserConstraintDecl *decl = new (rawData) UserConstraintDecl( 441 name, inputs.size(), results.size(), codeBlock, body, resultType); 442 std::uninitialized_copy(inputs.begin(), inputs.end(), 443 decl->getInputs().begin()); 444 std::uninitialized_copy(results.begin(), results.end(), 445 decl->getResults().begin()); 446 return decl; 447 } 448 449 //===----------------------------------------------------------------------===// 450 // NamedAttributeDecl 451 //===----------------------------------------------------------------------===// 452 453 NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name, 454 Expr *value) { 455 return new (ctx.getAllocator().Allocate<NamedAttributeDecl>()) 456 NamedAttributeDecl(name, value); 457 } 458 459 
//===----------------------------------------------------------------------===// 460 // OpNameDecl 461 //===----------------------------------------------------------------------===// 462 463 OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) { 464 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(name); 465 } 466 OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) { 467 return new (ctx.getAllocator().Allocate<OpNameDecl>()) OpNameDecl(loc); 468 } 469 470 //===----------------------------------------------------------------------===// 471 // PatternDecl 472 //===----------------------------------------------------------------------===// 473 474 PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, 475 const Name *name, Optional<uint16_t> benefit, 476 bool hasBoundedRecursion, 477 const CompoundStmt *body) { 478 return new (ctx.getAllocator().Allocate<PatternDecl>()) 479 PatternDecl(loc, name, benefit, hasBoundedRecursion, body); 480 } 481 482 //===----------------------------------------------------------------------===// 483 // UserRewriteDecl 484 //===----------------------------------------------------------------------===// 485 486 UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name, 487 ArrayRef<VariableDecl *> inputs, 488 ArrayRef<VariableDecl *> results, 489 Optional<StringRef> codeBlock, 490 const CompoundStmt *body, 491 Type resultType) { 492 unsigned allocSize = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>( 493 inputs.size() + results.size()); 494 void *rawData = 495 ctx.getAllocator().Allocate(allocSize, alignof(UserRewriteDecl)); 496 if (codeBlock) 497 codeBlock = codeBlock->copy(ctx.getAllocator()); 498 499 UserRewriteDecl *decl = new (rawData) UserRewriteDecl( 500 name, inputs.size(), results.size(), codeBlock, body, resultType); 501 std::uninitialized_copy(inputs.begin(), inputs.end(), 502 decl->getInputs().begin()); 503 std::uninitialized_copy(results.begin(), results.end(), 504 
decl->getResults().begin()); 505 return decl; 506 } 507 508 //===----------------------------------------------------------------------===// 509 // VariableDecl 510 //===----------------------------------------------------------------------===// 511 512 VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type, 513 Expr *initExpr, 514 ArrayRef<ConstraintRef> constraints) { 515 unsigned allocSize = 516 VariableDecl::totalSizeToAlloc<ConstraintRef>(constraints.size()); 517 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(VariableDecl)); 518 519 VariableDecl *varDecl = 520 new (rawData) VariableDecl(name, type, initExpr, constraints.size()); 521 std::uninitialized_copy(constraints.begin(), constraints.end(), 522 varDecl->getConstraints().begin()); 523 return varDecl; 524 } 525 526 //===----------------------------------------------------------------------===// 527 // Module 528 //===----------------------------------------------------------------------===// 529 530 Module *Module::create(Context &ctx, SMLoc loc, 531 ArrayRef<Decl *> children) { 532 unsigned allocSize = Module::totalSizeToAlloc<Decl *>(children.size()); 533 void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(Module)); 534 535 Module *module = new (rawData) Module(loc, children.size()); 536 std::uninitialized_copy(children.begin(), children.end(), 537 module->getChildren().begin()); 538 return module; 539 } 540