1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 23 24 using namespace mlir; 25 26 static ParseResult parsePDLOpTypedResults( 27 OpAsmParser &parser, SmallVectorImpl<Type> &types, 28 const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) { 29 types.resize(handles.size(), pdl::OperationType::get(parser.getContext())); 30 return success(); 31 } 32 33 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, 34 ValueRange) {} 35 36 #define GET_OP_CLASSES 37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 38 39 //===----------------------------------------------------------------------===// 40 // PatternApplicatorExtension 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 /// A simple pattern rewriter that can be constructed from a context. This is 45 /// necessary to apply patterns to a specific op locally. 46 class TrivialPatternRewriter : public PatternRewriter { 47 public: 48 explicit TrivialPatternRewriter(MLIRContext *context) 49 : PatternRewriter(context) {} 50 }; 51 52 /// A TransformState extension that keeps track of compiled PDL pattern sets. 53 /// This is intended to be used along the WithPDLPatterns op. The extension 54 /// can be constructed given an operation that has a SymbolTable trait and 55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 56 /// by one when requested; this behavior is subject to change. 57 class PatternApplicatorExtension : public transform::TransformState::Extension { 58 public: 59 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 60 61 /// Creates the extension for patterns contained in `patternContainer`. 62 explicit PatternApplicatorExtension(transform::TransformState &state, 63 Operation *patternContainer) 64 : Extension(state), patterns(patternContainer) {} 65 66 /// Appends to `results` the operations contained in `root` that matched the 67 /// PDL pattern with the given name. Note that `root` may or may not be the 68 /// operation that contains PDL patterns. Reports an error if the pattern 69 /// cannot be found. Note that when no operations are matched, this still 70 /// succeeds as long as the pattern exists. 71 LogicalResult findAllMatches(StringRef patternName, Operation *root, 72 SmallVectorImpl<Operation *> &results); 73 74 private: 75 /// Map from the pattern name to a singleton set of rewrite patterns that only 76 /// contains the pattern with this name. Populated when the pattern is first 77 /// requested. 78 // TODO: reconsider the efficiency of this storage when more usage data is 79 // available. Storing individual patterns in a set and triggering compilation 80 // for each of them has overhead. So does compiling a large set of patterns 81 // only to apply a handlful of them. 82 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 83 84 /// A symbol table operation containing the relevant PDL patterns. 85 SymbolTable patterns; 86 }; 87 88 LogicalResult PatternApplicatorExtension::findAllMatches( 89 StringRef patternName, Operation *root, 90 SmallVectorImpl<Operation *> &results) { 91 auto it = compiledPatterns.find(patternName); 92 if (it == compiledPatterns.end()) { 93 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 94 if (!patternOp) 95 return failure(); 96 97 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 98 patternOp->moveBefore(pdlModuleOp->getBody(), 99 pdlModuleOp->getBody()->end()); 100 PDLPatternModule patternModule(std::move(pdlModuleOp)); 101 102 // Merge in the hooks owned by the dialect. Make a copy as they may be 103 // also used by the following operations. 104 auto *dialect = 105 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 106 for (const auto &pair : dialect->getPDLConstraintHooks()) 107 patternModule.registerConstraintFunction(pair.first(), pair.second); 108 109 // Register a noop rewriter because PDL requires patterns to end with some 110 // rewrite call. 111 patternModule.registerRewriteFunction( 112 "transform.dialect", [](PatternRewriter &, Operation *) {}); 113 114 it = compiledPatterns 115 .try_emplace(patternOp.getName(), std::move(patternModule)) 116 .first; 117 } 118 119 PatternApplicator applicator(it->second); 120 TrivialPatternRewriter rewriter(root->getContext()); 121 applicator.applyDefaultCostModel(); 122 root->walk([&](Operation *op) { 123 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 124 results.push_back(op); 125 }); 126 127 return success(); 128 } 129 } // namespace 130 131 //===----------------------------------------------------------------------===// 132 // AlternativesOp 133 //===----------------------------------------------------------------------===// 134 135 OperandRange 136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) { 137 if (index && getOperation()->getNumOperands() == 1) 138 return getOperation()->getOperands(); 139 return OperandRange(getOperation()->operand_end(), 140 getOperation()->operand_end()); 141 } 142 143 void transform::AlternativesOp::getSuccessorRegions( 144 Optional<unsigned> index, ArrayRef<Attribute> operands, 145 SmallVectorImpl<RegionSuccessor> ®ions) { 146 for (Region &alternative : 147 llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) { 148 regions.emplace_back(&alternative, !getOperands().empty() 149 ? alternative.getArguments() 150 : Block::BlockArgListType()); 151 } 152 if (index.hasValue()) 153 regions.emplace_back(getOperation()->getResults()); 154 } 155 156 void transform::AlternativesOp::getRegionInvocationBounds( 157 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 158 (void)operands; 159 // The region corresponding to the first alternative is always executed, the 160 // remaining may or may not be executed. 161 bounds.reserve(getNumRegions()); 162 bounds.emplace_back(1, 1); 163 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 164 } 165 166 static void forwardTerminatorOperands(Block *block, 167 transform::TransformState &state, 168 transform::TransformResults &results) { 169 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), 170 block->getParentOp()->getOpResults())) { 171 Value terminatorOperand = std::get<0>(pair); 172 OpResult result = std::get<1>(pair); 173 results.set(result, state.getPayloadOps(terminatorOperand)); 174 } 175 } 176 177 DiagnosedSilenceableFailure 178 transform::AlternativesOp::apply(transform::TransformResults &results, 179 transform::TransformState &state) { 180 SmallVector<Operation *> originals; 181 if (Value scopeHandle = getScope()) 182 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 183 else 184 originals.push_back(state.getTopLevel()); 185 186 for (Operation *original : originals) { 187 if (original->isAncestor(getOperation())) { 188 InFlightDiagnostic diag = 189 emitError() << "scope must not contain the transforms being applied"; 190 diag.attachNote(original->getLoc()) << "scope"; 191 return DiagnosedSilenceableFailure::definiteFailure(); 192 } 193 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 194 InFlightDiagnostic diag = 195 emitError() 196 << "only isolated-from-above ops can be alternative scopes"; 197 diag.attachNote(original->getLoc()) << "scope"; 198 return DiagnosedSilenceableFailure(std::move(diag)); 199 } 200 } 201 202 for (Region ® : getAlternatives()) { 203 // Clone the scope operations and make the transforms in this alternative 204 // region apply to them by virtue of mapping the block argument (the only 205 // visible handle) to the cloned scope operations. This effectively prevents 206 // the transformation from accessing any IR outside the scope. 207 auto scope = state.make_region_scope(reg); 208 auto clones = llvm::to_vector( 209 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 210 auto deleteClones = llvm::make_scope_exit([&] { 211 for (Operation *clone : clones) 212 clone->erase(); 213 }); 214 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 215 return DiagnosedSilenceableFailure::definiteFailure(); 216 217 bool failed = false; 218 for (Operation &transform : reg.front().without_terminator()) { 219 DiagnosedSilenceableFailure result = 220 state.applyTransform(cast<TransformOpInterface>(transform)); 221 if (result.isSilenceableFailure()) { 222 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 223 << "\n"); 224 failed = true; 225 break; 226 } 227 228 if (::mlir::failed(result.silence())) 229 return DiagnosedSilenceableFailure::definiteFailure(); 230 } 231 232 // If all operations in the given alternative succeeded, no need to consider 233 // the rest. Replace the original scoping operation with the clone on which 234 // the transformations were performed. 235 if (!failed) { 236 // We will be using the clones, so cancel their scheduled deletion. 237 deleteClones.release(); 238 IRRewriter rewriter(getContext()); 239 for (const auto &kvp : llvm::zip(originals, clones)) { 240 Operation *original = std::get<0>(kvp); 241 Operation *clone = std::get<1>(kvp); 242 original->getBlock()->getOperations().insert(original->getIterator(), 243 clone); 244 rewriter.replaceOp(original, clone->getResults()); 245 } 246 forwardTerminatorOperands(®.front(), state, results); 247 return DiagnosedSilenceableFailure::success(); 248 } 249 } 250 return emitSilenceableError() << "all alternatives failed"; 251 } 252 253 LogicalResult transform::AlternativesOp::verify() { 254 for (Region &alternative : getAlternatives()) { 255 Block &block = alternative.front(); 256 if (block.getNumArguments() != 1 || 257 !block.getArgument(0).getType().isa<pdl::OperationType>()) { 258 return emitOpError() 259 << "expects region blocks to have one operand of type " 260 << pdl::OperationType::get(getContext()); 261 } 262 263 Operation *terminator = block.getTerminator(); 264 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 265 InFlightDiagnostic diag = emitOpError() 266 << "expects terminator operands to have the " 267 "same type as results of the operation"; 268 diag.attachNote(terminator->getLoc()) << "terminator"; 269 return diag; 270 } 271 } 272 273 return success(); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // GetClosestIsolatedParentOp 278 //===----------------------------------------------------------------------===// 279 280 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( 281 transform::TransformResults &results, transform::TransformState &state) { 282 SetVector<Operation *> parents; 283 for (Operation *target : state.getPayloadOps(getTarget())) { 284 Operation *parent = 285 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 286 if (!parent) { 287 DiagnosedSilenceableFailure diag = 288 emitSilenceableError() 289 << "could not find an isolated-from-above parent op"; 290 diag.attachNote(target->getLoc()) << "target op"; 291 return diag; 292 } 293 parents.insert(parent); 294 } 295 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 296 return DiagnosedSilenceableFailure::success(); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // MergeHandlesOp 301 //===----------------------------------------------------------------------===// 302 303 DiagnosedSilenceableFailure 304 transform::MergeHandlesOp::apply(transform::TransformResults &results, 305 transform::TransformState &state) { 306 SmallVector<Operation *> operations; 307 for (Value operand : getHandles()) 308 llvm::append_range(operations, state.getPayloadOps(operand)); 309 if (!getDeduplicate()) { 310 results.set(getResult().cast<OpResult>(), operations); 311 return DiagnosedSilenceableFailure::success(); 312 } 313 314 SetVector<Operation *> uniqued(operations.begin(), operations.end()); 315 results.set(getResult().cast<OpResult>(), uniqued.getArrayRef()); 316 return DiagnosedSilenceableFailure::success(); 317 } 318 319 void transform::MergeHandlesOp::getEffects( 320 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 321 for (Value operand : getHandles()) { 322 effects.emplace_back(MemoryEffects::Read::get(), operand, 323 transform::TransformMappingResource::get()); 324 effects.emplace_back(MemoryEffects::Free::get(), operand, 325 transform::TransformMappingResource::get()); 326 } 327 effects.emplace_back(MemoryEffects::Allocate::get(), getResult(), 328 transform::TransformMappingResource::get()); 329 effects.emplace_back(MemoryEffects::Write::get(), getResult(), 330 transform::TransformMappingResource::get()); 331 332 // There are no effects on the Payload IR as this is only a handle 333 // manipulation. 334 } 335 336 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) { 337 if (getDeduplicate() || getHandles().size() != 1) 338 return {}; 339 340 // If deduplication is not required and there is only one operand, it can be 341 // used directly instead of merging. 342 return getHandles().front(); 343 } 344 345 //===----------------------------------------------------------------------===// 346 // PDLMatchOp 347 //===----------------------------------------------------------------------===// 348 349 DiagnosedSilenceableFailure 350 transform::PDLMatchOp::apply(transform::TransformResults &results, 351 transform::TransformState &state) { 352 auto *extension = state.getExtension<PatternApplicatorExtension>(); 353 assert(extension && 354 "expected PatternApplicatorExtension to be attached by the parent op"); 355 SmallVector<Operation *> targets; 356 for (Operation *root : state.getPayloadOps(getRoot())) { 357 if (failed(extension->findAllMatches( 358 getPatternName().getLeafReference().getValue(), root, targets))) { 359 emitOpError() << "could not find pattern '" << getPatternName() << "'"; 360 return DiagnosedSilenceableFailure::definiteFailure(); 361 } 362 } 363 results.set(getResult().cast<OpResult>(), targets); 364 return DiagnosedSilenceableFailure::success(); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // ReplicateOp 369 //===----------------------------------------------------------------------===// 370 371 DiagnosedSilenceableFailure 372 transform::ReplicateOp::apply(transform::TransformResults &results, 373 transform::TransformState &state) { 374 unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); 375 for (const auto &en : llvm::enumerate(getHandles())) { 376 Value handle = en.value(); 377 ArrayRef<Operation *> current = state.getPayloadOps(handle); 378 SmallVector<Operation *> payload; 379 payload.reserve(numRepetitions * current.size()); 380 for (unsigned i = 0; i < numRepetitions; ++i) 381 llvm::append_range(payload, current); 382 results.set(getReplicated()[en.index()].cast<OpResult>(), payload); 383 } 384 return DiagnosedSilenceableFailure::success(); 385 } 386 387 void transform::ReplicateOp::getEffects( 388 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 389 onlyReadsHandle(getPattern(), effects); 390 consumesHandle(getHandles(), effects); 391 producesHandle(getReplicated(), effects); 392 } 393 394 //===----------------------------------------------------------------------===// 395 // SequenceOp 396 //===----------------------------------------------------------------------===// 397 398 DiagnosedSilenceableFailure 399 transform::SequenceOp::apply(transform::TransformResults &results, 400 transform::TransformState &state) { 401 // Map the entry block argument to the list of operations. 402 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 403 if (failed(mapBlockArguments(state))) 404 return DiagnosedSilenceableFailure::definiteFailure(); 405 406 // Apply the sequenced ops one by one. 407 for (Operation &transform : getBodyBlock()->without_terminator()) { 408 DiagnosedSilenceableFailure result = 409 state.applyTransform(cast<TransformOpInterface>(transform)); 410 if (!result.succeeded()) 411 return result; 412 } 413 414 // Forward the operation mapping for values yielded from the sequence to the 415 // values produced by the sequence op. 416 forwardTerminatorOperands(getBodyBlock(), state, results); 417 return DiagnosedSilenceableFailure::success(); 418 } 419 420 /// Returns `true` if the given op operand may be consuming the handle value in 421 /// the Transform IR. That is, if it may have a Free effect on it. 422 static bool isValueUsePotentialConsumer(OpOperand &use) { 423 // Conservatively assume the effect being present in absence of the interface. 424 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 425 if (!memEffectInterface) 426 return true; 427 428 SmallVector<MemoryEffects::EffectInstance, 2> effects; 429 memEffectInterface.getEffectsOnValue(use.get(), effects); 430 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 431 return isa<transform::TransformMappingResource>(effect.getResource()) && 432 isa<MemoryEffects::Free>(effect.getEffect()); 433 }); 434 } 435 436 LogicalResult 437 checkDoubleConsume(Value value, 438 function_ref<InFlightDiagnostic()> reportError) { 439 OpOperand *potentialConsumer = nullptr; 440 for (OpOperand &use : value.getUses()) { 441 if (!isValueUsePotentialConsumer(use)) 442 continue; 443 444 if (!potentialConsumer) { 445 potentialConsumer = &use; 446 continue; 447 } 448 449 InFlightDiagnostic diag = reportError() 450 << " has more than one potential consumer"; 451 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 452 << "used here as operand #" << potentialConsumer->getOperandNumber(); 453 diag.attachNote(use.getOwner()->getLoc()) 454 << "used here as operand #" << use.getOperandNumber(); 455 return diag; 456 } 457 458 return success(); 459 } 460 461 LogicalResult transform::SequenceOp::verify() { 462 // Check if the block argument has more than one consuming use. 463 for (BlockArgument argument : getBodyBlock()->getArguments()) { 464 auto report = [&]() { 465 return (emitOpError() << "block argument #" << argument.getArgNumber()); 466 }; 467 if (failed(checkDoubleConsume(argument, report))) 468 return failure(); 469 } 470 471 // Check properties of the nested operations they cannot check themselves. 472 for (Operation &child : *getBodyBlock()) { 473 if (!isa<TransformOpInterface>(child) && 474 &child != &getBodyBlock()->back()) { 475 InFlightDiagnostic diag = 476 emitOpError() 477 << "expected children ops to implement TransformOpInterface"; 478 diag.attachNote(child.getLoc()) << "op without interface"; 479 return diag; 480 } 481 482 for (OpResult result : child.getResults()) { 483 auto report = [&]() { 484 return (child.emitError() << "result #" << result.getResultNumber()); 485 }; 486 if (failed(checkDoubleConsume(result, report))) 487 return failure(); 488 } 489 } 490 491 if (getBodyBlock()->getTerminator()->getOperandTypes() != 492 getOperation()->getResultTypes()) { 493 InFlightDiagnostic diag = emitOpError() 494 << "expects the types of the terminator operands " 495 "to match the types of the result"; 496 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 497 return diag; 498 } 499 return success(); 500 } 501 502 void transform::SequenceOp::getEffects( 503 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 504 auto *mappingResource = TransformMappingResource::get(); 505 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 506 507 for (Value result : getResults()) { 508 effects.emplace_back(MemoryEffects::Allocate::get(), result, 509 mappingResource); 510 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 511 } 512 513 if (!getRoot()) { 514 for (Operation &op : *getBodyBlock()) { 515 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 516 if (!iface) { 517 // TODO: fill all possible effects; or require ops to actually implement 518 // the memory effect interface always 519 assert(false); 520 } 521 522 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 523 iface.getEffects(effects); 524 } 525 return; 526 } 527 528 // Carry over all effects on the argument of the entry block as those on the 529 // operand, this is the same value just remapped. 530 for (Operation &op : *getBodyBlock()) { 531 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 532 if (!iface) { 533 // TODO: fill all possible effects; or require ops to actually implement 534 // the memory effect interface always 535 assert(false); 536 } 537 538 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 539 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 540 for (const auto &effect : nestedEffects) 541 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 542 } 543 } 544 545 OperandRange 546 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { 547 assert(index && *index == 0 && "unexpected region index"); 548 if (getOperation()->getNumOperands() == 1) 549 return getOperation()->getOperands(); 550 return OperandRange(getOperation()->operand_end(), 551 getOperation()->operand_end()); 552 } 553 554 void transform::SequenceOp::getSuccessorRegions( 555 Optional<unsigned> index, ArrayRef<Attribute> operands, 556 SmallVectorImpl<RegionSuccessor> ®ions) { 557 if (!index) { 558 Region *bodyRegion = &getBody(); 559 regions.emplace_back(bodyRegion, !operands.empty() 560 ? bodyRegion->getArguments() 561 : Block::BlockArgListType()); 562 return; 563 } 564 565 assert(*index == 0 && "unexpected region index"); 566 regions.emplace_back(getOperation()->getResults()); 567 } 568 569 void transform::SequenceOp::getRegionInvocationBounds( 570 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 571 (void)operands; 572 bounds.emplace_back(1, 1); 573 } 574 575 //===----------------------------------------------------------------------===// 576 // WithPDLPatternsOp 577 //===----------------------------------------------------------------------===// 578 579 DiagnosedSilenceableFailure 580 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 581 transform::TransformState &state) { 582 OwningOpRef<ModuleOp> pdlModuleOp = 583 ModuleOp::create(getOperation()->getLoc()); 584 TransformOpInterface transformOp = nullptr; 585 for (Operation &nested : getBody().front()) { 586 if (!isa<pdl::PatternOp>(nested)) { 587 transformOp = cast<TransformOpInterface>(nested); 588 break; 589 } 590 } 591 592 state.addExtension<PatternApplicatorExtension>(getOperation()); 593 auto guard = llvm::make_scope_exit( 594 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 595 596 auto scope = state.make_region_scope(getBody()); 597 if (failed(mapBlockArguments(state))) 598 return DiagnosedSilenceableFailure::definiteFailure(); 599 return state.applyTransform(transformOp); 600 } 601 602 LogicalResult transform::WithPDLPatternsOp::verify() { 603 Block *body = getBodyBlock(); 604 Operation *topLevelOp = nullptr; 605 for (Operation &op : body->getOperations()) { 606 if (isa<pdl::PatternOp>(op)) 607 continue; 608 609 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 610 if (topLevelOp) { 611 InFlightDiagnostic diag = 612 emitOpError() << "expects only one non-pattern op in its body"; 613 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 614 diag.attachNote(op.getLoc()) << "second non-pattern op"; 615 return diag; 616 } 617 topLevelOp = &op; 618 continue; 619 } 620 621 InFlightDiagnostic diag = 622 emitOpError() 623 << "expects only pattern and top-level transform ops in its body"; 624 diag.attachNote(op.getLoc()) << "offending op"; 625 return diag; 626 } 627 628 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 629 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 630 diag.attachNote(parent.getLoc()) << "parent operation"; 631 return diag; 632 } 633 634 return success(); 635 } 636