1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 23 24 using namespace mlir; 25 26 #define GET_OP_CLASSES 27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 28 29 //===----------------------------------------------------------------------===// 30 // PatternApplicatorExtension 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 /// A simple pattern rewriter that can be constructed from a context. This is 35 /// necessary to apply patterns to a specific op locally. 36 class TrivialPatternRewriter : public PatternRewriter { 37 public: 38 explicit TrivialPatternRewriter(MLIRContext *context) 39 : PatternRewriter(context) {} 40 }; 41 42 /// A TransformState extension that keeps track of compiled PDL pattern sets. 43 /// This is intended to be used along the WithPDLPatterns op. The extension 44 /// can be constructed given an operation that has a SymbolTable trait and 45 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 46 /// by one when requested; this behavior is subject to change. 47 class PatternApplicatorExtension : public transform::TransformState::Extension { 48 public: 49 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 50 51 /// Creates the extension for patterns contained in `patternContainer`. 52 explicit PatternApplicatorExtension(transform::TransformState &state, 53 Operation *patternContainer) 54 : Extension(state), patterns(patternContainer) {} 55 56 /// Appends to `results` the operations contained in `root` that matched the 57 /// PDL pattern with the given name. Note that `root` may or may not be the 58 /// operation that contains PDL patterns. Reports an error if the pattern 59 /// cannot be found. Note that when no operations are matched, this still 60 /// succeeds as long as the pattern exists. 61 LogicalResult findAllMatches(StringRef patternName, Operation *root, 62 SmallVectorImpl<Operation *> &results); 63 64 private: 65 /// Map from the pattern name to a singleton set of rewrite patterns that only 66 /// contains the pattern with this name. Populated when the pattern is first 67 /// requested. 68 // TODO: reconsider the efficiency of this storage when more usage data is 69 // available. Storing individual patterns in a set and triggering compilation 70 // for each of them has overhead. So does compiling a large set of patterns 71 // only to apply a handlful of them. 72 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 73 74 /// A symbol table operation containing the relevant PDL patterns. 75 SymbolTable patterns; 76 }; 77 78 LogicalResult PatternApplicatorExtension::findAllMatches( 79 StringRef patternName, Operation *root, 80 SmallVectorImpl<Operation *> &results) { 81 auto it = compiledPatterns.find(patternName); 82 if (it == compiledPatterns.end()) { 83 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 84 if (!patternOp) 85 return failure(); 86 87 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 88 patternOp->moveBefore(pdlModuleOp->getBody(), 89 pdlModuleOp->getBody()->end()); 90 PDLPatternModule patternModule(std::move(pdlModuleOp)); 91 92 // Merge in the hooks owned by the dialect. Make a copy as they may be 93 // also used by the following operations. 94 auto *dialect = 95 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 96 for (const auto &pair : dialect->getPDLConstraintHooks()) 97 patternModule.registerConstraintFunction(pair.first(), pair.second); 98 99 // Register a noop rewriter because PDL requires patterns to end with some 100 // rewrite call. 101 patternModule.registerRewriteFunction( 102 "transform.dialect", [](PatternRewriter &, Operation *) {}); 103 104 it = compiledPatterns 105 .try_emplace(patternOp.getName(), std::move(patternModule)) 106 .first; 107 } 108 109 PatternApplicator applicator(it->second); 110 TrivialPatternRewriter rewriter(root->getContext()); 111 applicator.applyDefaultCostModel(); 112 root->walk([&](Operation *op) { 113 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 114 results.push_back(op); 115 }); 116 117 return success(); 118 } 119 } // namespace 120 121 //===----------------------------------------------------------------------===// 122 // AlternativesOp 123 //===----------------------------------------------------------------------===// 124 125 OperandRange 126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) { 127 if (index && getOperation()->getNumOperands() == 1) 128 return getOperation()->getOperands(); 129 return OperandRange(getOperation()->operand_end(), 130 getOperation()->operand_end()); 131 } 132 133 void transform::AlternativesOp::getSuccessorRegions( 134 Optional<unsigned> index, ArrayRef<Attribute> operands, 135 SmallVectorImpl<RegionSuccessor> ®ions) { 136 for (Region &alternative : 137 llvm::drop_begin(getAlternatives(), index ? *index + 1 : 0)) { 138 regions.emplace_back(&alternative, !getOperands().empty() 139 ? alternative.getArguments() 140 : Block::BlockArgListType()); 141 } 142 if (index) 143 regions.emplace_back(getOperation()->getResults()); 144 } 145 146 void transform::AlternativesOp::getRegionInvocationBounds( 147 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 148 (void)operands; 149 // The region corresponding to the first alternative is always executed, the 150 // remaining may or may not be executed. 151 bounds.reserve(getNumRegions()); 152 bounds.emplace_back(1, 1); 153 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 154 } 155 156 static void forwardTerminatorOperands(Block *block, 157 transform::TransformState &state, 158 transform::TransformResults &results) { 159 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), 160 block->getParentOp()->getOpResults())) { 161 Value terminatorOperand = std::get<0>(pair); 162 OpResult result = std::get<1>(pair); 163 results.set(result, state.getPayloadOps(terminatorOperand)); 164 } 165 } 166 167 DiagnosedSilenceableFailure 168 transform::AlternativesOp::apply(transform::TransformResults &results, 169 transform::TransformState &state) { 170 SmallVector<Operation *> originals; 171 if (Value scopeHandle = getScope()) 172 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 173 else 174 originals.push_back(state.getTopLevel()); 175 176 for (Operation *original : originals) { 177 if (original->isAncestor(getOperation())) { 178 InFlightDiagnostic diag = 179 emitError() << "scope must not contain the transforms being applied"; 180 diag.attachNote(original->getLoc()) << "scope"; 181 return DiagnosedSilenceableFailure::definiteFailure(); 182 } 183 if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) { 184 InFlightDiagnostic diag = 185 emitError() 186 << "only isolated-from-above ops can be alternative scopes"; 187 diag.attachNote(original->getLoc()) << "scope"; 188 return DiagnosedSilenceableFailure(std::move(diag)); 189 } 190 } 191 192 for (Region ® : getAlternatives()) { 193 // Clone the scope operations and make the transforms in this alternative 194 // region apply to them by virtue of mapping the block argument (the only 195 // visible handle) to the cloned scope operations. This effectively prevents 196 // the transformation from accessing any IR outside the scope. 197 auto scope = state.make_region_scope(reg); 198 auto clones = llvm::to_vector( 199 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 200 auto deleteClones = llvm::make_scope_exit([&] { 201 for (Operation *clone : clones) 202 clone->erase(); 203 }); 204 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 205 return DiagnosedSilenceableFailure::definiteFailure(); 206 207 bool failed = false; 208 for (Operation &transform : reg.front().without_terminator()) { 209 DiagnosedSilenceableFailure result = 210 state.applyTransform(cast<TransformOpInterface>(transform)); 211 if (result.isSilenceableFailure()) { 212 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 213 << "\n"); 214 failed = true; 215 break; 216 } 217 218 if (::mlir::failed(result.silence())) 219 return DiagnosedSilenceableFailure::definiteFailure(); 220 } 221 222 // If all operations in the given alternative succeeded, no need to consider 223 // the rest. Replace the original scoping operation with the clone on which 224 // the transformations were performed. 225 if (!failed) { 226 // We will be using the clones, so cancel their scheduled deletion. 227 deleteClones.release(); 228 IRRewriter rewriter(getContext()); 229 for (const auto &kvp : llvm::zip(originals, clones)) { 230 Operation *original = std::get<0>(kvp); 231 Operation *clone = std::get<1>(kvp); 232 original->getBlock()->getOperations().insert(original->getIterator(), 233 clone); 234 rewriter.replaceOp(original, clone->getResults()); 235 } 236 forwardTerminatorOperands(®.front(), state, results); 237 return DiagnosedSilenceableFailure::success(); 238 } 239 } 240 return emitSilenceableError() << "all alternatives failed"; 241 } 242 243 LogicalResult transform::AlternativesOp::verify() { 244 for (Region &alternative : getAlternatives()) { 245 Block &block = alternative.front(); 246 if (block.getNumArguments() != 1 || 247 !block.getArgument(0).getType().isa<pdl::OperationType>()) { 248 return emitOpError() 249 << "expects region blocks to have one operand of type " 250 << pdl::OperationType::get(getContext()); 251 } 252 253 Operation *terminator = block.getTerminator(); 254 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 255 InFlightDiagnostic diag = emitOpError() 256 << "expects terminator operands to have the " 257 "same type as results of the operation"; 258 diag.attachNote(terminator->getLoc()) << "terminator"; 259 return diag; 260 } 261 } 262 263 return success(); 264 } 265 266 //===----------------------------------------------------------------------===// 267 // GetClosestIsolatedParentOp 268 //===----------------------------------------------------------------------===// 269 270 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( 271 transform::TransformResults &results, transform::TransformState &state) { 272 SetVector<Operation *> parents; 273 for (Operation *target : state.getPayloadOps(getTarget())) { 274 Operation *parent = 275 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 276 if (!parent) { 277 DiagnosedSilenceableFailure diag = 278 emitSilenceableError() 279 << "could not find an isolated-from-above parent op"; 280 diag.attachNote(target->getLoc()) << "target op"; 281 return diag; 282 } 283 parents.insert(parent); 284 } 285 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 286 return DiagnosedSilenceableFailure::success(); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // PDLMatchOp 291 //===----------------------------------------------------------------------===// 292 293 DiagnosedSilenceableFailure 294 transform::PDLMatchOp::apply(transform::TransformResults &results, 295 transform::TransformState &state) { 296 auto *extension = state.getExtension<PatternApplicatorExtension>(); 297 assert(extension && 298 "expected PatternApplicatorExtension to be attached by the parent op"); 299 SmallVector<Operation *> targets; 300 for (Operation *root : state.getPayloadOps(getRoot())) { 301 if (failed(extension->findAllMatches( 302 getPatternName().getLeafReference().getValue(), root, targets))) { 303 emitOpError() << "could not find pattern '" << getPatternName() << "'"; 304 return DiagnosedSilenceableFailure::definiteFailure(); 305 } 306 } 307 results.set(getResult().cast<OpResult>(), targets); 308 return DiagnosedSilenceableFailure::success(); 309 } 310 311 //===----------------------------------------------------------------------===// 312 // SequenceOp 313 //===----------------------------------------------------------------------===// 314 315 DiagnosedSilenceableFailure 316 transform::SequenceOp::apply(transform::TransformResults &results, 317 transform::TransformState &state) { 318 // Map the entry block argument to the list of operations. 319 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 320 if (failed(mapBlockArguments(state))) 321 return DiagnosedSilenceableFailure::definiteFailure(); 322 323 // Apply the sequenced ops one by one. 324 for (Operation &transform : getBodyBlock()->without_terminator()) { 325 DiagnosedSilenceableFailure result = 326 state.applyTransform(cast<TransformOpInterface>(transform)); 327 if (!result.succeeded()) 328 return result; 329 } 330 331 // Forward the operation mapping for values yielded from the sequence to the 332 // values produced by the sequence op. 333 forwardTerminatorOperands(getBodyBlock(), state, results); 334 return DiagnosedSilenceableFailure::success(); 335 } 336 337 /// Returns `true` if the given op operand may be consuming the handle value in 338 /// the Transform IR. That is, if it may have a Free effect on it. 339 static bool isValueUsePotentialConsumer(OpOperand &use) { 340 // Conservatively assume the effect being present in absence of the interface. 341 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 342 if (!memEffectInterface) 343 return true; 344 345 SmallVector<MemoryEffects::EffectInstance, 2> effects; 346 memEffectInterface.getEffectsOnValue(use.get(), effects); 347 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 348 return isa<transform::TransformMappingResource>(effect.getResource()) && 349 isa<MemoryEffects::Free>(effect.getEffect()); 350 }); 351 } 352 353 LogicalResult 354 checkDoubleConsume(Value value, 355 function_ref<InFlightDiagnostic()> reportError) { 356 OpOperand *potentialConsumer = nullptr; 357 for (OpOperand &use : value.getUses()) { 358 if (!isValueUsePotentialConsumer(use)) 359 continue; 360 361 if (!potentialConsumer) { 362 potentialConsumer = &use; 363 continue; 364 } 365 366 InFlightDiagnostic diag = reportError() 367 << " has more than one potential consumer"; 368 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 369 << "used here as operand #" << potentialConsumer->getOperandNumber(); 370 diag.attachNote(use.getOwner()->getLoc()) 371 << "used here as operand #" << use.getOperandNumber(); 372 return diag; 373 } 374 375 return success(); 376 } 377 378 LogicalResult transform::SequenceOp::verify() { 379 // Check if the block argument has more than one consuming use. 380 for (BlockArgument argument : getBodyBlock()->getArguments()) { 381 auto report = [&]() { 382 return (emitOpError() << "block argument #" << argument.getArgNumber()); 383 }; 384 if (failed(checkDoubleConsume(argument, report))) 385 return failure(); 386 } 387 388 // Check properties of the nested operations they cannot check themselves. 389 for (Operation &child : *getBodyBlock()) { 390 if (!isa<TransformOpInterface>(child) && 391 &child != &getBodyBlock()->back()) { 392 InFlightDiagnostic diag = 393 emitOpError() 394 << "expected children ops to implement TransformOpInterface"; 395 diag.attachNote(child.getLoc()) << "op without interface"; 396 return diag; 397 } 398 399 for (OpResult result : child.getResults()) { 400 auto report = [&]() { 401 return (child.emitError() << "result #" << result.getResultNumber()); 402 }; 403 if (failed(checkDoubleConsume(result, report))) 404 return failure(); 405 } 406 } 407 408 if (getBodyBlock()->getTerminator()->getOperandTypes() != 409 getOperation()->getResultTypes()) { 410 InFlightDiagnostic diag = emitOpError() 411 << "expects the types of the terminator operands " 412 "to match the types of the result"; 413 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 414 return diag; 415 } 416 return success(); 417 } 418 419 void transform::SequenceOp::getEffects( 420 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 421 auto *mappingResource = TransformMappingResource::get(); 422 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 423 424 for (Value result : getResults()) { 425 effects.emplace_back(MemoryEffects::Allocate::get(), result, 426 mappingResource); 427 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 428 } 429 430 if (!getRoot()) { 431 for (Operation &op : *getBodyBlock()) { 432 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 433 if (!iface) { 434 // TODO: fill all possible effects; or require ops to actually implement 435 // the memory effect interface always 436 assert(false); 437 } 438 439 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 440 iface.getEffects(effects); 441 } 442 return; 443 } 444 445 // Carry over all effects on the argument of the entry block as those on the 446 // operand, this is the same value just remapped. 447 for (Operation &op : *getBodyBlock()) { 448 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 449 if (!iface) { 450 // TODO: fill all possible effects; or require ops to actually implement 451 // the memory effect interface always 452 assert(false); 453 } 454 455 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 456 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 457 for (const auto &effect : nestedEffects) 458 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 459 } 460 } 461 462 OperandRange 463 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { 464 assert(index && *index == 0 && "unexpected region index"); 465 if (getOperation()->getNumOperands() == 1) 466 return getOperation()->getOperands(); 467 return OperandRange(getOperation()->operand_end(), 468 getOperation()->operand_end()); 469 } 470 471 void transform::SequenceOp::getSuccessorRegions( 472 Optional<unsigned> index, ArrayRef<Attribute> operands, 473 SmallVectorImpl<RegionSuccessor> ®ions) { 474 if (!index) { 475 Region *bodyRegion = &getBody(); 476 regions.emplace_back(bodyRegion, !operands.empty() 477 ? bodyRegion->getArguments() 478 : Block::BlockArgListType()); 479 return; 480 } 481 482 assert(*index == 0 && "unexpected region index"); 483 regions.emplace_back(getOperation()->getResults()); 484 } 485 486 void transform::SequenceOp::getRegionInvocationBounds( 487 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 488 (void)operands; 489 bounds.emplace_back(1, 1); 490 } 491 492 //===----------------------------------------------------------------------===// 493 // WithPDLPatternsOp 494 //===----------------------------------------------------------------------===// 495 496 DiagnosedSilenceableFailure 497 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 498 transform::TransformState &state) { 499 OwningOpRef<ModuleOp> pdlModuleOp = 500 ModuleOp::create(getOperation()->getLoc()); 501 TransformOpInterface transformOp = nullptr; 502 for (Operation &nested : getBody().front()) { 503 if (!isa<pdl::PatternOp>(nested)) { 504 transformOp = cast<TransformOpInterface>(nested); 505 break; 506 } 507 } 508 509 state.addExtension<PatternApplicatorExtension>(getOperation()); 510 auto guard = llvm::make_scope_exit( 511 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 512 513 auto scope = state.make_region_scope(getBody()); 514 if (failed(mapBlockArguments(state))) 515 return DiagnosedSilenceableFailure::definiteFailure(); 516 return state.applyTransform(transformOp); 517 } 518 519 LogicalResult transform::WithPDLPatternsOp::verify() { 520 Block *body = getBodyBlock(); 521 Operation *topLevelOp = nullptr; 522 for (Operation &op : body->getOperations()) { 523 if (isa<pdl::PatternOp>(op)) 524 continue; 525 526 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 527 if (topLevelOp) { 528 InFlightDiagnostic diag = 529 emitOpError() << "expects only one non-pattern op in its body"; 530 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 531 diag.attachNote(op.getLoc()) << "second non-pattern op"; 532 return diag; 533 } 534 topLevelOp = &op; 535 continue; 536 } 537 538 InFlightDiagnostic diag = 539 emitOpError() 540 << "expects only pattern and top-level transform ops in its body"; 541 diag.attachNote(op.getLoc()) << "offending op"; 542 return diag; 543 } 544 545 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 546 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 547 diag.attachNote(parent.getLoc()) << "parent operation"; 548 return diag; 549 } 550 551 return success(); 552 } 553