1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 23 24 using namespace mlir; 25 26 #define GET_OP_CLASSES 27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 28 29 //===----------------------------------------------------------------------===// 30 // PatternApplicatorExtension 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 /// A simple pattern rewriter that can be constructed from a context. This is 35 /// necessary to apply patterns to a specific op locally. 36 class TrivialPatternRewriter : public PatternRewriter { 37 public: 38 explicit TrivialPatternRewriter(MLIRContext *context) 39 : PatternRewriter(context) {} 40 }; 41 42 /// A TransformState extension that keeps track of compiled PDL pattern sets. 43 /// This is intended to be used along the WithPDLPatterns op. The extension 44 /// can be constructed given an operation that has a SymbolTable trait and 45 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 46 /// by one when requested; this behavior is subject to change. 47 class PatternApplicatorExtension : public transform::TransformState::Extension { 48 public: 49 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 50 51 /// Creates the extension for patterns contained in `patternContainer`. 52 explicit PatternApplicatorExtension(transform::TransformState &state, 53 Operation *patternContainer) 54 : Extension(state), patterns(patternContainer) {} 55 56 /// Appends to `results` the operations contained in `root` that matched the 57 /// PDL pattern with the given name. Note that `root` may or may not be the 58 /// operation that contains PDL patterns. Reports an error if the pattern 59 /// cannot be found. Note that when no operations are matched, this still 60 /// succeeds as long as the pattern exists. 61 LogicalResult findAllMatches(StringRef patternName, Operation *root, 62 SmallVectorImpl<Operation *> &results); 63 64 private: 65 /// Map from the pattern name to a singleton set of rewrite patterns that only 66 /// contains the pattern with this name. Populated when the pattern is first 67 /// requested. 68 // TODO: reconsider the efficiency of this storage when more usage data is 69 // available. Storing individual patterns in a set and triggering compilation 70 // for each of them has overhead. So does compiling a large set of patterns 71 // only to apply a handlful of them. 72 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 73 74 /// A symbol table operation containing the relevant PDL patterns. 75 SymbolTable patterns; 76 }; 77 78 LogicalResult PatternApplicatorExtension::findAllMatches( 79 StringRef patternName, Operation *root, 80 SmallVectorImpl<Operation *> &results) { 81 auto it = compiledPatterns.find(patternName); 82 if (it == compiledPatterns.end()) { 83 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 84 if (!patternOp) 85 return failure(); 86 87 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 88 patternOp->moveBefore(pdlModuleOp->getBody(), 89 pdlModuleOp->getBody()->end()); 90 PDLPatternModule patternModule(std::move(pdlModuleOp)); 91 92 // Merge in the hooks owned by the dialect. Make a copy as they may be 93 // also used by the following operations. 94 auto *dialect = 95 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 96 for (const auto &pair : dialect->getPDLConstraintHooks()) 97 patternModule.registerConstraintFunction(pair.first(), pair.second); 98 99 // Register a noop rewriter because PDL requires patterns to end with some 100 // rewrite call. 101 patternModule.registerRewriteFunction( 102 "transform.dialect", [](PatternRewriter &, Operation *) {}); 103 104 it = compiledPatterns 105 .try_emplace(patternOp.getName(), std::move(patternModule)) 106 .first; 107 } 108 109 PatternApplicator applicator(it->second); 110 TrivialPatternRewriter rewriter(root->getContext()); 111 applicator.applyDefaultCostModel(); 112 root->walk([&](Operation *op) { 113 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 114 results.push_back(op); 115 }); 116 117 return success(); 118 } 119 } // namespace 120 121 //===----------------------------------------------------------------------===// 122 // AlternativesOp 123 //===----------------------------------------------------------------------===// 124 125 OperandRange 126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) { 127 if (index && getOperation()->getNumOperands() == 1) 128 return getOperation()->getOperands(); 129 return OperandRange(getOperation()->operand_end(), 130 getOperation()->operand_end()); 131 } 132 133 void transform::AlternativesOp::getSuccessorRegions( 134 Optional<unsigned> index, ArrayRef<Attribute> operands, 135 SmallVectorImpl<RegionSuccessor> ®ions) { 136 for (Region &alternative : 137 llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) { 138 regions.emplace_back(&alternative, !getOperands().empty() 139 ? alternative.getArguments() 140 : Block::BlockArgListType()); 141 } 142 if (index.hasValue()) 143 regions.emplace_back(getOperation()->getResults()); 144 } 145 146 void transform::AlternativesOp::getRegionInvocationBounds( 147 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 148 (void)operands; 149 // The region corresponding to the first alternative is always executed, the 150 // remaining may or may not be executed. 151 bounds.reserve(getNumRegions()); 152 bounds.emplace_back(1, 1); 153 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 154 } 155 156 static void forwardTerminatorOperands(Block *block, 157 transform::TransformState &state, 158 transform::TransformResults &results) { 159 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), 160 block->getParentOp()->getOpResults())) { 161 Value terminatorOperand = std::get<0>(pair); 162 OpResult result = std::get<1>(pair); 163 results.set(result, state.getPayloadOps(terminatorOperand)); 164 } 165 } 166 167 DiagnosedSilenceableFailure 168 transform::AlternativesOp::apply(transform::TransformResults &results, 169 transform::TransformState &state) { 170 SmallVector<Operation *> originals; 171 if (Value scopeHandle = getScope()) 172 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 173 else 174 originals.push_back(state.getTopLevel()); 175 176 for (Operation *original : originals) { 177 if (original->isAncestor(getOperation())) { 178 InFlightDiagnostic diag = 179 emitError() << "scope must not contain the transforms being applied"; 180 diag.attachNote(original->getLoc()) << "scope"; 181 return DiagnosedSilenceableFailure::definiteFailure(); 182 } 183 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 184 InFlightDiagnostic diag = 185 emitError() 186 << "only isolated-from-above ops can be alternative scopes"; 187 diag.attachNote(original->getLoc()) << "scope"; 188 return DiagnosedSilenceableFailure(std::move(diag)); 189 } 190 } 191 192 for (Region ® : getAlternatives()) { 193 // Clone the scope operations and make the transforms in this alternative 194 // region apply to them by virtue of mapping the block argument (the only 195 // visible handle) to the cloned scope operations. This effectively prevents 196 // the transformation from accessing any IR outside the scope. 197 auto scope = state.make_region_scope(reg); 198 auto clones = llvm::to_vector( 199 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 200 auto deleteClones = llvm::make_scope_exit([&] { 201 for (Operation *clone : clones) 202 clone->erase(); 203 }); 204 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 205 return DiagnosedSilenceableFailure::definiteFailure(); 206 207 bool failed = false; 208 for (Operation &transform : reg.front().without_terminator()) { 209 DiagnosedSilenceableFailure result = 210 state.applyTransform(cast<TransformOpInterface>(transform)); 211 if (result.isSilenceableFailure()) { 212 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 213 << "\n"); 214 failed = true; 215 break; 216 } 217 218 if (::mlir::failed(result.silence())) 219 return DiagnosedSilenceableFailure::definiteFailure(); 220 } 221 222 // If all operations in the given alternative succeeded, no need to consider 223 // the rest. Replace the original scoping operation with the clone on which 224 // the transformations were performed. 225 if (!failed) { 226 // We will be using the clones, so cancel their scheduled deletion. 227 deleteClones.release(); 228 IRRewriter rewriter(getContext()); 229 for (const auto &kvp : llvm::zip(originals, clones)) { 230 Operation *original = std::get<0>(kvp); 231 Operation *clone = std::get<1>(kvp); 232 original->getBlock()->getOperations().insert(original->getIterator(), 233 clone); 234 rewriter.replaceOp(original, clone->getResults()); 235 } 236 forwardTerminatorOperands(®.front(), state, results); 237 return DiagnosedSilenceableFailure::success(); 238 } 239 } 240 return emitSilenceableError() << "all alternatives failed"; 241 } 242 243 LogicalResult transform::AlternativesOp::verify() { 244 for (Region &alternative : getAlternatives()) { 245 Block &block = alternative.front(); 246 if (block.getNumArguments() != 1 || 247 !block.getArgument(0).getType().isa<pdl::OperationType>()) { 248 return emitOpError() 249 << "expects region blocks to have one operand of type " 250 << pdl::OperationType::get(getContext()); 251 } 252 253 Operation *terminator = block.getTerminator(); 254 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 255 InFlightDiagnostic diag = emitOpError() 256 << "expects terminator operands to have the " 257 "same type as results of the operation"; 258 diag.attachNote(terminator->getLoc()) << "terminator"; 259 return diag; 260 } 261 } 262 263 return success(); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // GetClosestIsolatedParentOp 268 //===----------------------------------------------------------------------===// 269 270 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( 271 transform::TransformResults &results, transform::TransformState &state) { 272 SetVector<Operation *> parents; 273 for (Operation *target : state.getPayloadOps(getTarget())) { 274 Operation *parent = 275 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 276 if (!parent) { 277 DiagnosedSilenceableFailure diag = 278 emitSilenceableError() 279 << "could not find an isolated-from-above parent op"; 280 diag.attachNote(target->getLoc()) << "target op"; 281 return diag; 282 } 283 parents.insert(parent); 284 } 285 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 286 return DiagnosedSilenceableFailure::success(); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // MergeHandlesOp 291 //===----------------------------------------------------------------------===// 292 293 DiagnosedSilenceableFailure 294 transform::MergeHandlesOp::apply(transform::TransformResults &results, 295 transform::TransformState &state) { 296 SmallVector<Operation *> operations; 297 for (Value operand : getHandles()) 298 llvm::append_range(operations, state.getPayloadOps(operand)); 299 if (!getDeduplicate()) { 300 results.set(getResult().cast<OpResult>(), operations); 301 return DiagnosedSilenceableFailure::success(); 302 } 303 304 SetVector<Operation *> uniqued(operations.begin(), operations.end()); 305 results.set(getResult().cast<OpResult>(), uniqued.getArrayRef()); 306 return DiagnosedSilenceableFailure::success(); 307 } 308 309 void transform::MergeHandlesOp::getEffects( 310 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 311 for (Value operand : getHandles()) { 312 effects.emplace_back(MemoryEffects::Read::get(), operand, 313 transform::TransformMappingResource::get()); 314 effects.emplace_back(MemoryEffects::Free::get(), operand, 315 transform::TransformMappingResource::get()); 316 } 317 effects.emplace_back(MemoryEffects::Allocate::get(), getResult(), 318 transform::TransformMappingResource::get()); 319 effects.emplace_back(MemoryEffects::Write::get(), getResult(), 320 transform::TransformMappingResource::get()); 321 322 // There are no effects on the Payload IR as this is only a handle 323 // manipulation. 324 } 325 326 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) { 327 if (getDeduplicate() || getHandles().size() != 1) 328 return {}; 329 330 // If deduplication is not required and there is only one operand, it can be 331 // used directly instead of merging. 332 return getHandles().front(); 333 } 334 335 //===----------------------------------------------------------------------===// 336 // PDLMatchOp 337 //===----------------------------------------------------------------------===// 338 339 DiagnosedSilenceableFailure 340 transform::PDLMatchOp::apply(transform::TransformResults &results, 341 transform::TransformState &state) { 342 auto *extension = state.getExtension<PatternApplicatorExtension>(); 343 assert(extension && 344 "expected PatternApplicatorExtension to be attached by the parent op"); 345 SmallVector<Operation *> targets; 346 for (Operation *root : state.getPayloadOps(getRoot())) { 347 if (failed(extension->findAllMatches( 348 getPatternName().getLeafReference().getValue(), root, targets))) { 349 emitOpError() << "could not find pattern '" << getPatternName() << "'"; 350 return DiagnosedSilenceableFailure::definiteFailure(); 351 } 352 } 353 results.set(getResult().cast<OpResult>(), targets); 354 return DiagnosedSilenceableFailure::success(); 355 } 356 357 //===----------------------------------------------------------------------===// 358 // SequenceOp 359 //===----------------------------------------------------------------------===// 360 361 DiagnosedSilenceableFailure 362 transform::SequenceOp::apply(transform::TransformResults &results, 363 transform::TransformState &state) { 364 // Map the entry block argument to the list of operations. 365 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 366 if (failed(mapBlockArguments(state))) 367 return DiagnosedSilenceableFailure::definiteFailure(); 368 369 // Apply the sequenced ops one by one. 370 for (Operation &transform : getBodyBlock()->without_terminator()) { 371 DiagnosedSilenceableFailure result = 372 state.applyTransform(cast<TransformOpInterface>(transform)); 373 if (!result.succeeded()) 374 return result; 375 } 376 377 // Forward the operation mapping for values yielded from the sequence to the 378 // values produced by the sequence op. 379 forwardTerminatorOperands(getBodyBlock(), state, results); 380 return DiagnosedSilenceableFailure::success(); 381 } 382 383 /// Returns `true` if the given op operand may be consuming the handle value in 384 /// the Transform IR. That is, if it may have a Free effect on it. 385 static bool isValueUsePotentialConsumer(OpOperand &use) { 386 // Conservatively assume the effect being present in absence of the interface. 387 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 388 if (!memEffectInterface) 389 return true; 390 391 SmallVector<MemoryEffects::EffectInstance, 2> effects; 392 memEffectInterface.getEffectsOnValue(use.get(), effects); 393 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 394 return isa<transform::TransformMappingResource>(effect.getResource()) && 395 isa<MemoryEffects::Free>(effect.getEffect()); 396 }); 397 } 398 399 LogicalResult 400 checkDoubleConsume(Value value, 401 function_ref<InFlightDiagnostic()> reportError) { 402 OpOperand *potentialConsumer = nullptr; 403 for (OpOperand &use : value.getUses()) { 404 if (!isValueUsePotentialConsumer(use)) 405 continue; 406 407 if (!potentialConsumer) { 408 potentialConsumer = &use; 409 continue; 410 } 411 412 InFlightDiagnostic diag = reportError() 413 << " has more than one potential consumer"; 414 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 415 << "used here as operand #" << potentialConsumer->getOperandNumber(); 416 diag.attachNote(use.getOwner()->getLoc()) 417 << "used here as operand #" << use.getOperandNumber(); 418 return diag; 419 } 420 421 return success(); 422 } 423 424 LogicalResult transform::SequenceOp::verify() { 425 // Check if the block argument has more than one consuming use. 426 for (BlockArgument argument : getBodyBlock()->getArguments()) { 427 auto report = [&]() { 428 return (emitOpError() << "block argument #" << argument.getArgNumber()); 429 }; 430 if (failed(checkDoubleConsume(argument, report))) 431 return failure(); 432 } 433 434 // Check properties of the nested operations they cannot check themselves. 435 for (Operation &child : *getBodyBlock()) { 436 if (!isa<TransformOpInterface>(child) && 437 &child != &getBodyBlock()->back()) { 438 InFlightDiagnostic diag = 439 emitOpError() 440 << "expected children ops to implement TransformOpInterface"; 441 diag.attachNote(child.getLoc()) << "op without interface"; 442 return diag; 443 } 444 445 for (OpResult result : child.getResults()) { 446 auto report = [&]() { 447 return (child.emitError() << "result #" << result.getResultNumber()); 448 }; 449 if (failed(checkDoubleConsume(result, report))) 450 return failure(); 451 } 452 } 453 454 if (getBodyBlock()->getTerminator()->getOperandTypes() != 455 getOperation()->getResultTypes()) { 456 InFlightDiagnostic diag = emitOpError() 457 << "expects the types of the terminator operands " 458 "to match the types of the result"; 459 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 460 return diag; 461 } 462 return success(); 463 } 464 465 void transform::SequenceOp::getEffects( 466 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 467 auto *mappingResource = TransformMappingResource::get(); 468 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 469 470 for (Value result : getResults()) { 471 effects.emplace_back(MemoryEffects::Allocate::get(), result, 472 mappingResource); 473 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 474 } 475 476 if (!getRoot()) { 477 for (Operation &op : *getBodyBlock()) { 478 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 479 if (!iface) { 480 // TODO: fill all possible effects; or require ops to actually implement 481 // the memory effect interface always 482 assert(false); 483 } 484 485 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 486 iface.getEffects(effects); 487 } 488 return; 489 } 490 491 // Carry over all effects on the argument of the entry block as those on the 492 // operand, this is the same value just remapped. 493 for (Operation &op : *getBodyBlock()) { 494 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 495 if (!iface) { 496 // TODO: fill all possible effects; or require ops to actually implement 497 // the memory effect interface always 498 assert(false); 499 } 500 501 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 502 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 503 for (const auto &effect : nestedEffects) 504 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 505 } 506 } 507 508 OperandRange 509 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { 510 assert(index && *index == 0 && "unexpected region index"); 511 if (getOperation()->getNumOperands() == 1) 512 return getOperation()->getOperands(); 513 return OperandRange(getOperation()->operand_end(), 514 getOperation()->operand_end()); 515 } 516 517 void transform::SequenceOp::getSuccessorRegions( 518 Optional<unsigned> index, ArrayRef<Attribute> operands, 519 SmallVectorImpl<RegionSuccessor> ®ions) { 520 if (!index) { 521 Region *bodyRegion = &getBody(); 522 regions.emplace_back(bodyRegion, !operands.empty() 523 ? bodyRegion->getArguments() 524 : Block::BlockArgListType()); 525 return; 526 } 527 528 assert(*index == 0 && "unexpected region index"); 529 regions.emplace_back(getOperation()->getResults()); 530 } 531 532 void transform::SequenceOp::getRegionInvocationBounds( 533 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 534 (void)operands; 535 bounds.emplace_back(1, 1); 536 } 537 538 //===----------------------------------------------------------------------===// 539 // WithPDLPatternsOp 540 //===----------------------------------------------------------------------===// 541 542 DiagnosedSilenceableFailure 543 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 544 transform::TransformState &state) { 545 OwningOpRef<ModuleOp> pdlModuleOp = 546 ModuleOp::create(getOperation()->getLoc()); 547 TransformOpInterface transformOp = nullptr; 548 for (Operation &nested : getBody().front()) { 549 if (!isa<pdl::PatternOp>(nested)) { 550 transformOp = cast<TransformOpInterface>(nested); 551 break; 552 } 553 } 554 555 state.addExtension<PatternApplicatorExtension>(getOperation()); 556 auto guard = llvm::make_scope_exit( 557 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 558 559 auto scope = state.make_region_scope(getBody()); 560 if (failed(mapBlockArguments(state))) 561 return DiagnosedSilenceableFailure::definiteFailure(); 562 return state.applyTransform(transformOp); 563 } 564 565 LogicalResult transform::WithPDLPatternsOp::verify() { 566 Block *body = getBodyBlock(); 567 Operation *topLevelOp = nullptr; 568 for (Operation &op : body->getOperations()) { 569 if (isa<pdl::PatternOp>(op)) 570 continue; 571 572 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 573 if (topLevelOp) { 574 InFlightDiagnostic diag = 575 emitOpError() << "expects only one non-pattern op in its body"; 576 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 577 diag.attachNote(op.getLoc()) << "second non-pattern op"; 578 return diag; 579 } 580 topLevelOp = &op; 581 continue; 582 } 583 584 InFlightDiagnostic diag = 585 emitOpError() 586 << "expects only pattern and top-level transform ops in its body"; 587 diag.attachNote(op.getLoc()) << "offending op"; 588 return diag; 589 } 590 591 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 592 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 593 diag.attachNote(parent.getLoc()) << "parent operation"; 594 return diag; 595 } 596 597 return success(); 598 } 599