1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 23 24 using namespace mlir; 25 26 static ParseResult parsePDLOpTypedResults( 27 OpAsmParser &parser, SmallVectorImpl<Type> &types, 28 const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) { 29 types.resize(handles.size(), pdl::OperationType::get(parser.getContext())); 30 return success(); 31 } 32 33 static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, 34 ValueRange) {} 35 36 #define GET_OP_CLASSES 37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 38 39 //===----------------------------------------------------------------------===// 40 // PatternApplicatorExtension 41 //===----------------------------------------------------------------------===// 42 43 namespace { 44 /// A simple pattern rewriter that can be constructed from a context. This is 45 /// necessary to apply patterns to a specific op locally. 46 class TrivialPatternRewriter : public PatternRewriter { 47 public: 48 explicit TrivialPatternRewriter(MLIRContext *context) 49 : PatternRewriter(context) {} 50 }; 51 52 /// A TransformState extension that keeps track of compiled PDL pattern sets. 53 /// This is intended to be used along the WithPDLPatterns op. The extension 54 /// can be constructed given an operation that has a SymbolTable trait and 55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 56 /// by one when requested; this behavior is subject to change. 57 class PatternApplicatorExtension : public transform::TransformState::Extension { 58 public: 59 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 60 61 /// Creates the extension for patterns contained in `patternContainer`. 62 explicit PatternApplicatorExtension(transform::TransformState &state, 63 Operation *patternContainer) 64 : Extension(state), patterns(patternContainer) {} 65 66 /// Appends to `results` the operations contained in `root` that matched the 67 /// PDL pattern with the given name. Note that `root` may or may not be the 68 /// operation that contains PDL patterns. Reports an error if the pattern 69 /// cannot be found. Note that when no operations are matched, this still 70 /// succeeds as long as the pattern exists. 71 LogicalResult findAllMatches(StringRef patternName, Operation *root, 72 SmallVectorImpl<Operation *> &results); 73 74 private: 75 /// Map from the pattern name to a singleton set of rewrite patterns that only 76 /// contains the pattern with this name. Populated when the pattern is first 77 /// requested. 78 // TODO: reconsider the efficiency of this storage when more usage data is 79 // available. Storing individual patterns in a set and triggering compilation 80 // for each of them has overhead. So does compiling a large set of patterns 81 // only to apply a handlful of them. 82 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 83 84 /// A symbol table operation containing the relevant PDL patterns. 85 SymbolTable patterns; 86 }; 87 88 LogicalResult PatternApplicatorExtension::findAllMatches( 89 StringRef patternName, Operation *root, 90 SmallVectorImpl<Operation *> &results) { 91 auto it = compiledPatterns.find(patternName); 92 if (it == compiledPatterns.end()) { 93 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 94 if (!patternOp) 95 return failure(); 96 97 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 98 patternOp->moveBefore(pdlModuleOp->getBody(), 99 pdlModuleOp->getBody()->end()); 100 PDLPatternModule patternModule(std::move(pdlModuleOp)); 101 102 // Merge in the hooks owned by the dialect. Make a copy as they may be 103 // also used by the following operations. 104 auto *dialect = 105 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 106 for (const auto &pair : dialect->getPDLConstraintHooks()) 107 patternModule.registerConstraintFunction(pair.first(), pair.second); 108 109 // Register a noop rewriter because PDL requires patterns to end with some 110 // rewrite call. 111 patternModule.registerRewriteFunction( 112 "transform.dialect", [](PatternRewriter &, Operation *) {}); 113 114 it = compiledPatterns 115 .try_emplace(patternOp.getName(), std::move(patternModule)) 116 .first; 117 } 118 119 PatternApplicator applicator(it->second); 120 TrivialPatternRewriter rewriter(root->getContext()); 121 applicator.applyDefaultCostModel(); 122 root->walk([&](Operation *op) { 123 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 124 results.push_back(op); 125 }); 126 127 return success(); 128 } 129 } // namespace 130 131 //===----------------------------------------------------------------------===// 132 // AlternativesOp 133 //===----------------------------------------------------------------------===// 134 135 OperandRange 136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) { 137 if (index && getOperation()->getNumOperands() == 1) 138 return getOperation()->getOperands(); 139 return OperandRange(getOperation()->operand_end(), 140 getOperation()->operand_end()); 141 } 142 143 void transform::AlternativesOp::getSuccessorRegions( 144 Optional<unsigned> index, ArrayRef<Attribute> operands, 145 SmallVectorImpl<RegionSuccessor> ®ions) { 146 for (Region &alternative : llvm::drop_begin( 147 getAlternatives(), index.has_value() ? *index + 1 : 0)) { 148 regions.emplace_back(&alternative, !getOperands().empty() 149 ? alternative.getArguments() 150 : Block::BlockArgListType()); 151 } 152 if (index.has_value()) 153 regions.emplace_back(getOperation()->getResults()); 154 } 155 156 void transform::AlternativesOp::getRegionInvocationBounds( 157 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 158 (void)operands; 159 // The region corresponding to the first alternative is always executed, the 160 // remaining may or may not be executed. 161 bounds.reserve(getNumRegions()); 162 bounds.emplace_back(1, 1); 163 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 164 } 165 166 static void forwardTerminatorOperands(Block *block, 167 transform::TransformState &state, 168 transform::TransformResults &results) { 169 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), 170 block->getParentOp()->getOpResults())) { 171 Value terminatorOperand = std::get<0>(pair); 172 OpResult result = std::get<1>(pair); 173 results.set(result, state.getPayloadOps(terminatorOperand)); 174 } 175 } 176 177 DiagnosedSilenceableFailure 178 transform::AlternativesOp::apply(transform::TransformResults &results, 179 transform::TransformState &state) { 180 SmallVector<Operation *> originals; 181 if (Value scopeHandle = getScope()) 182 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 183 else 184 originals.push_back(state.getTopLevel()); 185 186 for (Operation *original : originals) { 187 if (original->isAncestor(getOperation())) { 188 InFlightDiagnostic diag = 189 emitError() << "scope must not contain the transforms being applied"; 190 diag.attachNote(original->getLoc()) << "scope"; 191 return DiagnosedSilenceableFailure::definiteFailure(); 192 } 193 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 194 InFlightDiagnostic diag = 195 emitError() 196 << "only isolated-from-above ops can be alternative scopes"; 197 diag.attachNote(original->getLoc()) << "scope"; 198 return DiagnosedSilenceableFailure(std::move(diag)); 199 } 200 } 201 202 for (Region ® : getAlternatives()) { 203 // Clone the scope operations and make the transforms in this alternative 204 // region apply to them by virtue of mapping the block argument (the only 205 // visible handle) to the cloned scope operations. This effectively prevents 206 // the transformation from accessing any IR outside the scope. 207 auto scope = state.make_region_scope(reg); 208 auto clones = llvm::to_vector( 209 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 210 auto deleteClones = llvm::make_scope_exit([&] { 211 for (Operation *clone : clones) 212 clone->erase(); 213 }); 214 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 215 return DiagnosedSilenceableFailure::definiteFailure(); 216 217 bool failed = false; 218 for (Operation &transform : reg.front().without_terminator()) { 219 DiagnosedSilenceableFailure result = 220 state.applyTransform(cast<TransformOpInterface>(transform)); 221 if (result.isSilenceableFailure()) { 222 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 223 << "\n"); 224 failed = true; 225 break; 226 } 227 228 if (::mlir::failed(result.silence())) 229 return DiagnosedSilenceableFailure::definiteFailure(); 230 } 231 232 // If all operations in the given alternative succeeded, no need to consider 233 // the rest. Replace the original scoping operation with the clone on which 234 // the transformations were performed. 235 if (!failed) { 236 // We will be using the clones, so cancel their scheduled deletion. 237 deleteClones.release(); 238 IRRewriter rewriter(getContext()); 239 for (const auto &kvp : llvm::zip(originals, clones)) { 240 Operation *original = std::get<0>(kvp); 241 Operation *clone = std::get<1>(kvp); 242 original->getBlock()->getOperations().insert(original->getIterator(), 243 clone); 244 rewriter.replaceOp(original, clone->getResults()); 245 } 246 forwardTerminatorOperands(®.front(), state, results); 247 return DiagnosedSilenceableFailure::success(); 248 } 249 } 250 return emitSilenceableError() << "all alternatives failed"; 251 } 252 253 LogicalResult transform::AlternativesOp::verify() { 254 for (Region &alternative : getAlternatives()) { 255 Block &block = alternative.front(); 256 if (block.getNumArguments() != 1 || 257 !block.getArgument(0).getType().isa<pdl::OperationType>()) { 258 return emitOpError() 259 << "expects region blocks to have one operand of type " 260 << pdl::OperationType::get(getContext()); 261 } 262 263 Operation *terminator = block.getTerminator(); 264 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 265 InFlightDiagnostic diag = emitOpError() 266 << "expects terminator operands to have the " 267 "same type as results of the operation"; 268 diag.attachNote(terminator->getLoc()) << "terminator"; 269 return diag; 270 } 271 } 272 273 return success(); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // ForeachOp 278 //===----------------------------------------------------------------------===// 279 280 DiagnosedSilenceableFailure 281 transform::ForeachOp::apply(transform::TransformResults &results, 282 transform::TransformState &state) { 283 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); 284 for (Operation *op : payloadOps) { 285 auto scope = state.make_region_scope(getBody()); 286 if (failed(state.mapBlockArguments(getIterationVariable(), {op}))) 287 return DiagnosedSilenceableFailure::definiteFailure(); 288 289 for (Operation &transform : getBody().front().without_terminator()) { 290 DiagnosedSilenceableFailure result = state.applyTransform( 291 cast<transform::TransformOpInterface>(transform)); 292 if (!result.succeeded()) 293 return result; 294 } 295 } 296 return DiagnosedSilenceableFailure::success(); 297 } 298 299 void transform::ForeachOp::getEffects( 300 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 301 BlockArgument iterVar = getIterationVariable(); 302 if (any_of(getBody().front().without_terminator(), [&](Operation &op) { 303 return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op)); 304 })) { 305 consumesHandle(getTarget(), effects); 306 } else { 307 onlyReadsHandle(getTarget(), effects); 308 } 309 } 310 311 void transform::ForeachOp::getSuccessorRegions( 312 Optional<unsigned> index, ArrayRef<Attribute> operands, 313 SmallVectorImpl<RegionSuccessor> ®ions) { 314 Region *bodyRegion = &getBody(); 315 if (!index) { 316 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 317 return; 318 } 319 320 // Branch back to the region or the parent. 321 assert(*index == 0 && "unexpected region index"); 322 regions.emplace_back(bodyRegion, bodyRegion->getArguments()); 323 regions.emplace_back(); 324 } 325 326 OperandRange 327 transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) { 328 // The iteration variable op handle is mapped to a subset (one op to be 329 // precise) of the payload ops of the ForeachOp operand. 330 assert(index && *index == 0 && "unexpected region index"); 331 return getOperation()->getOperands(); 332 } 333 334 //===----------------------------------------------------------------------===// 335 // GetClosestIsolatedParentOp 336 //===----------------------------------------------------------------------===// 337 338 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( 339 transform::TransformResults &results, transform::TransformState &state) { 340 SetVector<Operation *> parents; 341 for (Operation *target : state.getPayloadOps(getTarget())) { 342 Operation *parent = 343 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 344 if (!parent) { 345 DiagnosedSilenceableFailure diag = 346 emitSilenceableError() 347 << "could not find an isolated-from-above parent op"; 348 diag.attachNote(target->getLoc()) << "target op"; 349 return diag; 350 } 351 parents.insert(parent); 352 } 353 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 354 return DiagnosedSilenceableFailure::success(); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // MergeHandlesOp 359 //===----------------------------------------------------------------------===// 360 361 DiagnosedSilenceableFailure 362 transform::MergeHandlesOp::apply(transform::TransformResults &results, 363 transform::TransformState &state) { 364 SmallVector<Operation *> operations; 365 for (Value operand : getHandles()) 366 llvm::append_range(operations, state.getPayloadOps(operand)); 367 if (!getDeduplicate()) { 368 results.set(getResult().cast<OpResult>(), operations); 369 return DiagnosedSilenceableFailure::success(); 370 } 371 372 SetVector<Operation *> uniqued(operations.begin(), operations.end()); 373 results.set(getResult().cast<OpResult>(), uniqued.getArrayRef()); 374 return DiagnosedSilenceableFailure::success(); 375 } 376 377 void transform::MergeHandlesOp::getEffects( 378 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 379 consumesHandle(getHandles(), effects); 380 producesHandle(getResult(), effects); 381 382 // There are no effects on the Payload IR as this is only a handle 383 // manipulation. 384 } 385 386 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) { 387 if (getDeduplicate() || getHandles().size() != 1) 388 return {}; 389 390 // If deduplication is not required and there is only one operand, it can be 391 // used directly instead of merging. 392 return getHandles().front(); 393 } 394 395 //===----------------------------------------------------------------------===// 396 // PDLMatchOp 397 //===----------------------------------------------------------------------===// 398 399 DiagnosedSilenceableFailure 400 transform::PDLMatchOp::apply(transform::TransformResults &results, 401 transform::TransformState &state) { 402 auto *extension = state.getExtension<PatternApplicatorExtension>(); 403 assert(extension && 404 "expected PatternApplicatorExtension to be attached by the parent op"); 405 SmallVector<Operation *> targets; 406 for (Operation *root : state.getPayloadOps(getRoot())) { 407 if (failed(extension->findAllMatches( 408 getPatternName().getLeafReference().getValue(), root, targets))) { 409 emitOpError() << "could not find pattern '" << getPatternName() << "'"; 410 return DiagnosedSilenceableFailure::definiteFailure(); 411 } 412 } 413 results.set(getResult().cast<OpResult>(), targets); 414 return DiagnosedSilenceableFailure::success(); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // ReplicateOp 419 //===----------------------------------------------------------------------===// 420 421 DiagnosedSilenceableFailure 422 transform::ReplicateOp::apply(transform::TransformResults &results, 423 transform::TransformState &state) { 424 unsigned numRepetitions = state.getPayloadOps(getPattern()).size(); 425 for (const auto &en : llvm::enumerate(getHandles())) { 426 Value handle = en.value(); 427 ArrayRef<Operation *> current = state.getPayloadOps(handle); 428 SmallVector<Operation *> payload; 429 payload.reserve(numRepetitions * current.size()); 430 for (unsigned i = 0; i < numRepetitions; ++i) 431 llvm::append_range(payload, current); 432 results.set(getReplicated()[en.index()].cast<OpResult>(), payload); 433 } 434 return DiagnosedSilenceableFailure::success(); 435 } 436 437 void transform::ReplicateOp::getEffects( 438 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 439 onlyReadsHandle(getPattern(), effects); 440 consumesHandle(getHandles(), effects); 441 producesHandle(getReplicated(), effects); 442 } 443 444 //===----------------------------------------------------------------------===// 445 // SequenceOp 446 //===----------------------------------------------------------------------===// 447 448 DiagnosedSilenceableFailure 449 transform::SequenceOp::apply(transform::TransformResults &results, 450 transform::TransformState &state) { 451 // Map the entry block argument to the list of operations. 452 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 453 if (failed(mapBlockArguments(state))) 454 return DiagnosedSilenceableFailure::definiteFailure(); 455 456 // Apply the sequenced ops one by one. 457 for (Operation &transform : getBodyBlock()->without_terminator()) { 458 DiagnosedSilenceableFailure result = 459 state.applyTransform(cast<TransformOpInterface>(transform)); 460 if (!result.succeeded()) 461 return result; 462 } 463 464 // Forward the operation mapping for values yielded from the sequence to the 465 // values produced by the sequence op. 466 forwardTerminatorOperands(getBodyBlock(), state, results); 467 return DiagnosedSilenceableFailure::success(); 468 } 469 470 /// Returns `true` if the given op operand may be consuming the handle value in 471 /// the Transform IR. That is, if it may have a Free effect on it. 472 static bool isValueUsePotentialConsumer(OpOperand &use) { 473 // Conservatively assume the effect being present in absence of the interface. 474 auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner()); 475 if (!iface) 476 return true; 477 478 return isHandleConsumed(use.get(), iface); 479 } 480 481 LogicalResult 482 checkDoubleConsume(Value value, 483 function_ref<InFlightDiagnostic()> reportError) { 484 OpOperand *potentialConsumer = nullptr; 485 for (OpOperand &use : value.getUses()) { 486 if (!isValueUsePotentialConsumer(use)) 487 continue; 488 489 if (!potentialConsumer) { 490 potentialConsumer = &use; 491 continue; 492 } 493 494 InFlightDiagnostic diag = reportError() 495 << " has more than one potential consumer"; 496 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 497 << "used here as operand #" << potentialConsumer->getOperandNumber(); 498 diag.attachNote(use.getOwner()->getLoc()) 499 << "used here as operand #" << use.getOperandNumber(); 500 return diag; 501 } 502 503 return success(); 504 } 505 506 LogicalResult transform::SequenceOp::verify() { 507 // Check if the block argument has more than one consuming use. 508 for (BlockArgument argument : getBodyBlock()->getArguments()) { 509 auto report = [&]() { 510 return (emitOpError() << "block argument #" << argument.getArgNumber()); 511 }; 512 if (failed(checkDoubleConsume(argument, report))) 513 return failure(); 514 } 515 516 // Check properties of the nested operations they cannot check themselves. 517 for (Operation &child : *getBodyBlock()) { 518 if (!isa<TransformOpInterface>(child) && 519 &child != &getBodyBlock()->back()) { 520 InFlightDiagnostic diag = 521 emitOpError() 522 << "expected children ops to implement TransformOpInterface"; 523 diag.attachNote(child.getLoc()) << "op without interface"; 524 return diag; 525 } 526 527 for (OpResult result : child.getResults()) { 528 auto report = [&]() { 529 return (child.emitError() << "result #" << result.getResultNumber()); 530 }; 531 if (failed(checkDoubleConsume(result, report))) 532 return failure(); 533 } 534 } 535 536 if (getBodyBlock()->getTerminator()->getOperandTypes() != 537 getOperation()->getResultTypes()) { 538 InFlightDiagnostic diag = emitOpError() 539 << "expects the types of the terminator operands " 540 "to match the types of the result"; 541 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 542 return diag; 543 } 544 return success(); 545 } 546 547 void transform::SequenceOp::getEffects( 548 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 549 auto *mappingResource = TransformMappingResource::get(); 550 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 551 552 for (Value result : getResults()) { 553 effects.emplace_back(MemoryEffects::Allocate::get(), result, 554 mappingResource); 555 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 556 } 557 558 if (!getRoot()) { 559 for (Operation &op : *getBodyBlock()) { 560 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 561 if (!iface) { 562 // TODO: fill all possible effects; or require ops to actually implement 563 // the memory effect interface always 564 assert(false); 565 } 566 567 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 568 iface.getEffects(effects); 569 } 570 return; 571 } 572 573 // Carry over all effects on the argument of the entry block as those on the 574 // operand, this is the same value just remapped. 575 for (Operation &op : *getBodyBlock()) { 576 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 577 if (!iface) { 578 // TODO: fill all possible effects; or require ops to actually implement 579 // the memory effect interface always 580 assert(false); 581 } 582 583 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 584 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 585 for (const auto &effect : nestedEffects) 586 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 587 } 588 } 589 590 OperandRange 591 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { 592 assert(index && *index == 0 && "unexpected region index"); 593 if (getOperation()->getNumOperands() == 1) 594 return getOperation()->getOperands(); 595 return OperandRange(getOperation()->operand_end(), 596 getOperation()->operand_end()); 597 } 598 599 void transform::SequenceOp::getSuccessorRegions( 600 Optional<unsigned> index, ArrayRef<Attribute> operands, 601 SmallVectorImpl<RegionSuccessor> ®ions) { 602 if (!index) { 603 Region *bodyRegion = &getBody(); 604 regions.emplace_back(bodyRegion, !operands.empty() 605 ? bodyRegion->getArguments() 606 : Block::BlockArgListType()); 607 return; 608 } 609 610 assert(*index == 0 && "unexpected region index"); 611 regions.emplace_back(getOperation()->getResults()); 612 } 613 614 void transform::SequenceOp::getRegionInvocationBounds( 615 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 616 (void)operands; 617 bounds.emplace_back(1, 1); 618 } 619 620 //===----------------------------------------------------------------------===// 621 // WithPDLPatternsOp 622 //===----------------------------------------------------------------------===// 623 624 DiagnosedSilenceableFailure 625 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 626 transform::TransformState &state) { 627 OwningOpRef<ModuleOp> pdlModuleOp = 628 ModuleOp::create(getOperation()->getLoc()); 629 TransformOpInterface transformOp = nullptr; 630 for (Operation &nested : getBody().front()) { 631 if (!isa<pdl::PatternOp>(nested)) { 632 transformOp = cast<TransformOpInterface>(nested); 633 break; 634 } 635 } 636 637 state.addExtension<PatternApplicatorExtension>(getOperation()); 638 auto guard = llvm::make_scope_exit( 639 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 640 641 auto scope = state.make_region_scope(getBody()); 642 if (failed(mapBlockArguments(state))) 643 return DiagnosedSilenceableFailure::definiteFailure(); 644 return state.applyTransform(transformOp); 645 } 646 647 LogicalResult transform::WithPDLPatternsOp::verify() { 648 Block *body = getBodyBlock(); 649 Operation *topLevelOp = nullptr; 650 for (Operation &op : body->getOperations()) { 651 if (isa<pdl::PatternOp>(op)) 652 continue; 653 654 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 655 if (topLevelOp) { 656 InFlightDiagnostic diag = 657 emitOpError() << "expects only one non-pattern op in its body"; 658 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 659 diag.attachNote(op.getLoc()) << "second non-pattern op"; 660 return diag; 661 } 662 topLevelOp = &op; 663 continue; 664 } 665 666 InFlightDiagnostic diag = 667 emitOpError() 668 << "expects only pattern and top-level transform ops in its body"; 669 diag.attachNote(op.getLoc()) << "offending op"; 670 return diag; 671 } 672 673 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 674 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 675 diag.attachNote(parent.getLoc()) << "parent operation"; 676 return diag; 677 } 678 679 return success(); 680 } 681