1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 23 24 using namespace mlir; 25 26 #define GET_OP_CLASSES 27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 28 29 //===----------------------------------------------------------------------===// 30 // PatternApplicatorExtension 31 //===----------------------------------------------------------------------===// 32 33 namespace { 34 /// A simple pattern rewriter that can be constructed from a context. This is 35 /// necessary to apply patterns to a specific op locally. 36 class TrivialPatternRewriter : public PatternRewriter { 37 public: 38 explicit TrivialPatternRewriter(MLIRContext *context) 39 : PatternRewriter(context) {} 40 }; 41 42 /// A TransformState extension that keeps track of compiled PDL pattern sets. 43 /// This is intended to be used along the WithPDLPatterns op. The extension 44 /// can be constructed given an operation that has a SymbolTable trait and 45 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 46 /// by one when requested; this behavior is subject to change. 47 class PatternApplicatorExtension : public transform::TransformState::Extension { 48 public: 49 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 50 51 /// Creates the extension for patterns contained in `patternContainer`. 52 explicit PatternApplicatorExtension(transform::TransformState &state, 53 Operation *patternContainer) 54 : Extension(state), patterns(patternContainer) {} 55 56 /// Appends to `results` the operations contained in `root` that matched the 57 /// PDL pattern with the given name. Note that `root` may or may not be the 58 /// operation that contains PDL patterns. Reports an error if the pattern 59 /// cannot be found. Note that when no operations are matched, this still 60 /// succeeds as long as the pattern exists. 61 LogicalResult findAllMatches(StringRef patternName, Operation *root, 62 SmallVectorImpl<Operation *> &results); 63 64 private: 65 /// Map from the pattern name to a singleton set of rewrite patterns that only 66 /// contains the pattern with this name. Populated when the pattern is first 67 /// requested. 68 // TODO: reconsider the efficiency of this storage when more usage data is 69 // available. Storing individual patterns in a set and triggering compilation 70 // for each of them has overhead. So does compiling a large set of patterns 71 // only to apply a handlful of them. 72 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 73 74 /// A symbol table operation containing the relevant PDL patterns. 75 SymbolTable patterns; 76 }; 77 78 LogicalResult PatternApplicatorExtension::findAllMatches( 79 StringRef patternName, Operation *root, 80 SmallVectorImpl<Operation *> &results) { 81 auto it = compiledPatterns.find(patternName); 82 if (it == compiledPatterns.end()) { 83 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 84 if (!patternOp) 85 return failure(); 86 87 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 88 patternOp->moveBefore(pdlModuleOp->getBody(), 89 pdlModuleOp->getBody()->end()); 90 PDLPatternModule patternModule(std::move(pdlModuleOp)); 91 92 // Merge in the hooks owned by the dialect. Make a copy as they may be 93 // also used by the following operations. 94 auto *dialect = 95 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 96 for (const auto &pair : dialect->getPDLConstraintHooks()) 97 patternModule.registerConstraintFunction(pair.first(), pair.second); 98 99 // Register a noop rewriter because PDL requires patterns to end with some 100 // rewrite call. 101 patternModule.registerRewriteFunction( 102 "transform.dialect", [](PatternRewriter &, Operation *) {}); 103 104 it = compiledPatterns 105 .try_emplace(patternOp.getName(), std::move(patternModule)) 106 .first; 107 } 108 109 PatternApplicator applicator(it->second); 110 TrivialPatternRewriter rewriter(root->getContext()); 111 applicator.applyDefaultCostModel(); 112 root->walk([&](Operation *op) { 113 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 114 results.push_back(op); 115 }); 116 117 return success(); 118 } 119 } // namespace 120 121 //===----------------------------------------------------------------------===// 122 // AlternativesOp 123 //===----------------------------------------------------------------------===// 124 125 OperandRange 126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) { 127 if (index.hasValue() && getOperation()->getNumOperands() == 1) 128 return getOperation()->getOperands(); 129 return OperandRange(getOperation()->operand_end(), 130 getOperation()->operand_end()); 131 } 132 133 void transform::AlternativesOp::getSuccessorRegions( 134 Optional<unsigned> index, ArrayRef<Attribute> operands, 135 SmallVectorImpl<RegionSuccessor> ®ions) { 136 for (Region &alternative : 137 llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) { 138 regions.emplace_back(&alternative, !getOperands().empty() 139 ? alternative.getArguments() 140 : Block::BlockArgListType()); 141 } 142 if (index.hasValue()) 143 regions.emplace_back(getOperation()->getResults()); 144 } 145 146 void transform::AlternativesOp::getRegionInvocationBounds( 147 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 148 (void)operands; 149 // The region corresponding to the first alternative is always executed, the 150 // remaining may or may not be executed. 151 bounds.reserve(getNumRegions()); 152 bounds.emplace_back(1, 1); 153 bounds.resize(getNumRegions(), InvocationBounds(0, 1)); 154 } 155 156 static void forwardTerminatorOperands(Block *block, 157 transform::TransformState &state, 158 transform::TransformResults &results) { 159 for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), 160 block->getParentOp()->getOpResults())) { 161 Value terminatorOperand = std::get<0>(pair); 162 OpResult result = std::get<1>(pair); 163 results.set(result, state.getPayloadOps(terminatorOperand)); 164 } 165 } 166 167 DiagnosedSilencableFailure 168 transform::AlternativesOp::apply(transform::TransformResults &results, 169 transform::TransformState &state) { 170 SmallVector<Operation *> originals; 171 if (Value scopeHandle = getScope()) 172 llvm::append_range(originals, state.getPayloadOps(scopeHandle)); 173 else 174 originals.push_back(state.getTopLevel()); 175 176 for (Operation *original : originals) { 177 if (original->isAncestor(getOperation())) { 178 InFlightDiagnostic diag = 179 emitError() << "scope must not contain the transforms being applied"; 180 diag.attachNote(original->getLoc()) << "scope"; 181 return DiagnosedSilencableFailure::definiteFailure(); 182 } 183 } 184 185 for (Region ® : getAlternatives()) { 186 // Clone the scope operations and make the transforms in this alternative 187 // region apply to them by virtue of mapping the block argument (the only 188 // visible handle) to the cloned scope operations. This effectively prevents 189 // the transformation from accessing any IR outside the scope. 190 auto scope = state.make_region_scope(reg); 191 auto clones = llvm::to_vector( 192 llvm::map_range(originals, [](Operation *op) { return op->clone(); })); 193 if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) 194 return DiagnosedSilencableFailure::definiteFailure(); 195 auto deleteClones = llvm::make_scope_exit([&] { 196 for (Operation *clone : clones) 197 clone->erase(); 198 }); 199 200 bool failed = false; 201 for (Operation &transform : reg.front().without_terminator()) { 202 DiagnosedSilencableFailure result = 203 state.applyTransform(cast<TransformOpInterface>(transform)); 204 if (result.isSilencableFailure()) { 205 LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() 206 << "\n"); 207 failed = true; 208 break; 209 } 210 211 if (::mlir::failed(result.silence())) 212 return DiagnosedSilencableFailure::definiteFailure(); 213 } 214 215 // If all operations in the given alternative succeeded, no need to consider 216 // the rest. Replace the original scoping operation with the clone on which 217 // the transformations were performed. 218 if (!failed) { 219 // We will be using the clones, so cancel their scheduled deletion. 220 deleteClones.release(); 221 IRRewriter rewriter(getContext()); 222 for (const auto &kvp : llvm::zip(originals, clones)) { 223 Operation *original = std::get<0>(kvp); 224 Operation *clone = std::get<1>(kvp); 225 original->getBlock()->getOperations().insert(original->getIterator(), 226 clone); 227 rewriter.replaceOp(original, clone->getResults()); 228 } 229 forwardTerminatorOperands(®.front(), state, results); 230 return DiagnosedSilencableFailure::success(); 231 } 232 } 233 return emitSilencableError() << "all alternatives failed"; 234 } 235 236 LogicalResult transform::AlternativesOp::verify() { 237 for (Region &alternative : getAlternatives()) { 238 Block &block = alternative.front(); 239 if (block.getNumArguments() != 1 || 240 !block.getArgument(0).getType().isa<pdl::OperationType>()) { 241 return emitOpError() 242 << "expects region blocks to have one operand of type " 243 << pdl::OperationType::get(getContext()); 244 } 245 246 Operation *terminator = block.getTerminator(); 247 if (terminator->getOperands().getTypes() != getResults().getTypes()) { 248 InFlightDiagnostic diag = emitOpError() 249 << "expects terminator operands to have the " 250 "same type as results of the operation"; 251 diag.attachNote(terminator->getLoc()) << "terminator"; 252 return diag; 253 } 254 } 255 256 return success(); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // GetClosestIsolatedParentOp 261 //===----------------------------------------------------------------------===// 262 263 DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply( 264 transform::TransformResults &results, transform::TransformState &state) { 265 SetVector<Operation *> parents; 266 for (Operation *target : state.getPayloadOps(getTarget())) { 267 Operation *parent = 268 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 269 if (!parent) { 270 DiagnosedSilencableFailure diag = 271 emitSilencableError() 272 << "could not find an isolated-from-above parent op"; 273 diag.attachNote(target->getLoc()) << "target op"; 274 return diag; 275 } 276 parents.insert(parent); 277 } 278 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 279 return DiagnosedSilencableFailure::success(); 280 } 281 282 //===----------------------------------------------------------------------===// 283 // PDLMatchOp 284 //===----------------------------------------------------------------------===// 285 286 DiagnosedSilencableFailure 287 transform::PDLMatchOp::apply(transform::TransformResults &results, 288 transform::TransformState &state) { 289 auto *extension = state.getExtension<PatternApplicatorExtension>(); 290 assert(extension && 291 "expected PatternApplicatorExtension to be attached by the parent op"); 292 SmallVector<Operation *> targets; 293 for (Operation *root : state.getPayloadOps(getRoot())) { 294 if (failed(extension->findAllMatches( 295 getPatternName().getLeafReference().getValue(), root, targets))) { 296 emitOpError() << "could not find pattern '" << getPatternName() << "'"; 297 return DiagnosedSilencableFailure::definiteFailure(); 298 } 299 } 300 results.set(getResult().cast<OpResult>(), targets); 301 return DiagnosedSilencableFailure::success(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // SequenceOp 306 //===----------------------------------------------------------------------===// 307 308 DiagnosedSilencableFailure 309 transform::SequenceOp::apply(transform::TransformResults &results, 310 transform::TransformState &state) { 311 // Map the entry block argument to the list of operations. 312 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 313 if (failed(mapBlockArguments(state))) 314 return DiagnosedSilencableFailure::definiteFailure(); 315 316 // Apply the sequenced ops one by one. 317 for (Operation &transform : getBodyBlock()->without_terminator()) { 318 DiagnosedSilencableFailure result = 319 state.applyTransform(cast<TransformOpInterface>(transform)); 320 if (!result.succeeded()) 321 return result; 322 } 323 324 // Forward the operation mapping for values yielded from the sequence to the 325 // values produced by the sequence op. 326 forwardTerminatorOperands(getBodyBlock(), state, results); 327 return DiagnosedSilencableFailure::success(); 328 } 329 330 /// Returns `true` if the given op operand may be consuming the handle value in 331 /// the Transform IR. That is, if it may have a Free effect on it. 332 static bool isValueUsePotentialConsumer(OpOperand &use) { 333 // Conservatively assume the effect being present in absence of the interface. 334 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 335 if (!memEffectInterface) 336 return true; 337 338 SmallVector<MemoryEffects::EffectInstance, 2> effects; 339 memEffectInterface.getEffectsOnValue(use.get(), effects); 340 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 341 return isa<transform::TransformMappingResource>(effect.getResource()) && 342 isa<MemoryEffects::Free>(effect.getEffect()); 343 }); 344 } 345 346 LogicalResult 347 checkDoubleConsume(Value value, 348 function_ref<InFlightDiagnostic()> reportError) { 349 OpOperand *potentialConsumer = nullptr; 350 for (OpOperand &use : value.getUses()) { 351 if (!isValueUsePotentialConsumer(use)) 352 continue; 353 354 if (!potentialConsumer) { 355 potentialConsumer = &use; 356 continue; 357 } 358 359 InFlightDiagnostic diag = reportError() 360 << " has more than one potential consumer"; 361 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 362 << "used here as operand #" << potentialConsumer->getOperandNumber(); 363 diag.attachNote(use.getOwner()->getLoc()) 364 << "used here as operand #" << use.getOperandNumber(); 365 return diag; 366 } 367 368 return success(); 369 } 370 371 LogicalResult transform::SequenceOp::verify() { 372 // Check if the block argument has more than one consuming use. 373 for (BlockArgument argument : getBodyBlock()->getArguments()) { 374 auto report = [&]() { 375 return (emitOpError() << "block argument #" << argument.getArgNumber()); 376 }; 377 if (failed(checkDoubleConsume(argument, report))) 378 return failure(); 379 } 380 381 // Check properties of the nested operations they cannot check themselves. 382 for (Operation &child : *getBodyBlock()) { 383 if (!isa<TransformOpInterface>(child) && 384 &child != &getBodyBlock()->back()) { 385 InFlightDiagnostic diag = 386 emitOpError() 387 << "expected children ops to implement TransformOpInterface"; 388 diag.attachNote(child.getLoc()) << "op without interface"; 389 return diag; 390 } 391 392 for (OpResult result : child.getResults()) { 393 auto report = [&]() { 394 return (child.emitError() << "result #" << result.getResultNumber()); 395 }; 396 if (failed(checkDoubleConsume(result, report))) 397 return failure(); 398 } 399 } 400 401 if (getBodyBlock()->getTerminator()->getOperandTypes() != 402 getOperation()->getResultTypes()) { 403 InFlightDiagnostic diag = emitOpError() 404 << "expects the types of the terminator operands " 405 "to match the types of the result"; 406 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 407 return diag; 408 } 409 return success(); 410 } 411 412 void transform::SequenceOp::getEffects( 413 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 414 auto *mappingResource = TransformMappingResource::get(); 415 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 416 417 for (Value result : getResults()) { 418 effects.emplace_back(MemoryEffects::Allocate::get(), result, 419 mappingResource); 420 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 421 } 422 423 if (!getRoot()) { 424 for (Operation &op : *getBodyBlock()) { 425 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 426 if (!iface) { 427 // TODO: fill all possible effects; or require ops to actually implement 428 // the memory effect interface always 429 assert(false); 430 } 431 432 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 433 iface.getEffects(effects); 434 } 435 return; 436 } 437 438 // Carry over all effects on the argument of the entry block as those on the 439 // operand, this is the same value just remapped. 440 for (Operation &op : *getBodyBlock()) { 441 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 442 if (!iface) { 443 // TODO: fill all possible effects; or require ops to actually implement 444 // the memory effect interface always 445 assert(false); 446 } 447 448 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 449 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 450 for (const auto &effect : nestedEffects) 451 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 452 } 453 } 454 455 OperandRange 456 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { 457 assert(index && *index == 0 && "unexpected region index"); 458 if (getOperation()->getNumOperands() == 1) 459 return getOperation()->getOperands(); 460 return OperandRange(getOperation()->operand_end(), 461 getOperation()->operand_end()); 462 } 463 464 void transform::SequenceOp::getSuccessorRegions( 465 Optional<unsigned> index, ArrayRef<Attribute> operands, 466 SmallVectorImpl<RegionSuccessor> ®ions) { 467 if (!index.hasValue()) { 468 Region *bodyRegion = &getBody(); 469 regions.emplace_back(bodyRegion, !operands.empty() 470 ? bodyRegion->getArguments() 471 : Block::BlockArgListType()); 472 return; 473 } 474 475 assert(*index == 0 && "unexpected region index"); 476 regions.emplace_back(getOperation()->getResults()); 477 } 478 479 void transform::SequenceOp::getRegionInvocationBounds( 480 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 481 (void)operands; 482 bounds.emplace_back(1, 1); 483 } 484 485 //===----------------------------------------------------------------------===// 486 // WithPDLPatternsOp 487 //===----------------------------------------------------------------------===// 488 489 DiagnosedSilencableFailure 490 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 491 transform::TransformState &state) { 492 OwningOpRef<ModuleOp> pdlModuleOp = 493 ModuleOp::create(getOperation()->getLoc()); 494 TransformOpInterface transformOp = nullptr; 495 for (Operation &nested : getBody().front()) { 496 if (!isa<pdl::PatternOp>(nested)) { 497 transformOp = cast<TransformOpInterface>(nested); 498 break; 499 } 500 } 501 502 state.addExtension<PatternApplicatorExtension>(getOperation()); 503 auto guard = llvm::make_scope_exit( 504 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 505 506 auto scope = state.make_region_scope(getBody()); 507 if (failed(mapBlockArguments(state))) 508 return DiagnosedSilencableFailure::definiteFailure(); 509 return state.applyTransform(transformOp); 510 } 511 512 LogicalResult transform::WithPDLPatternsOp::verify() { 513 Block *body = getBodyBlock(); 514 Operation *topLevelOp = nullptr; 515 for (Operation &op : body->getOperations()) { 516 if (isa<pdl::PatternOp>(op)) 517 continue; 518 519 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 520 if (topLevelOp) { 521 InFlightDiagnostic diag = 522 emitOpError() << "expects only one non-pattern op in its body"; 523 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 524 diag.attachNote(op.getLoc()) << "second non-pattern op"; 525 return diag; 526 } 527 topLevelOp = &op; 528 continue; 529 } 530 531 InFlightDiagnostic diag = 532 emitOpError() 533 << "expects only pattern and top-level transform ops in its body"; 534 diag.attachNote(op.getLoc()) << "offending op"; 535 return diag; 536 } 537 538 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 539 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 540 diag.attachNote(parent.getLoc()) << "parent operation"; 541 return diag; 542 } 543 544 return success(); 545 } 546