1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Interfaces/ControlFlowInterfaces.h" 17 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 18 #include "mlir/Rewrite/PatternApplicator.h" 19 #include "llvm/ADT/ScopeExit.h" 20 21 using namespace mlir; 22 23 #define GET_OP_CLASSES 24 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 25 26 //===----------------------------------------------------------------------===// 27 // PatternApplicatorExtension 28 //===----------------------------------------------------------------------===// 29 30 namespace { 31 /// A simple pattern rewriter that can be constructed from a context. This is 32 /// necessary to apply patterns to a specific op locally. 33 class TrivialPatternRewriter : public PatternRewriter { 34 public: 35 explicit TrivialPatternRewriter(MLIRContext *context) 36 : PatternRewriter(context) {} 37 }; 38 39 /// A TransformState extension that keeps track of compiled PDL pattern sets. 40 /// This is intended to be used along the WithPDLPatterns op. The extension 41 /// can be constructed given an operation that has a SymbolTable trait and 42 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 43 /// by one when requested; this behavior is subject to change. 44 class PatternApplicatorExtension : public transform::TransformState::Extension { 45 public: 46 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 47 48 /// Creates the extension for patterns contained in `patternContainer`. 49 explicit PatternApplicatorExtension(transform::TransformState &state, 50 Operation *patternContainer) 51 : Extension(state), patterns(patternContainer) {} 52 53 /// Appends to `results` the operations contained in `root` that matched the 54 /// PDL pattern with the given name. Note that `root` may or may not be the 55 /// operation that contains PDL patterns. Reports an error if the pattern 56 /// cannot be found. Note that when no operations are matched, this still 57 /// succeeds as long as the pattern exists. 58 LogicalResult findAllMatches(StringRef patternName, Operation *root, 59 SmallVectorImpl<Operation *> &results); 60 61 private: 62 /// Map from the pattern name to a singleton set of rewrite patterns that only 63 /// contains the pattern with this name. Populated when the pattern is first 64 /// requested. 65 // TODO: reconsider the efficiency of this storage when more usage data is 66 // available. Storing individual patterns in a set and triggering compilation 67 // for each of them has overhead. So does compiling a large set of patterns 68 // only to apply a handlful of them. 69 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 70 71 /// A symbol table operation containing the relevant PDL patterns. 72 SymbolTable patterns; 73 }; 74 75 LogicalResult PatternApplicatorExtension::findAllMatches( 76 StringRef patternName, Operation *root, 77 SmallVectorImpl<Operation *> &results) { 78 auto it = compiledPatterns.find(patternName); 79 if (it == compiledPatterns.end()) { 80 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 81 if (!patternOp) 82 return failure(); 83 84 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 85 patternOp->moveBefore(pdlModuleOp->getBody(), 86 pdlModuleOp->getBody()->end()); 87 PDLPatternModule patternModule(std::move(pdlModuleOp)); 88 89 // Merge in the hooks owned by the dialect. Make a copy as they may be 90 // also used by the following operations. 91 auto *dialect = 92 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 93 for (const auto &pair : dialect->getPDLConstraintHooks()) 94 patternModule.registerConstraintFunction(pair.first(), pair.second); 95 96 // Register a noop rewriter because PDL requires patterns to end with some 97 // rewrite call. 98 patternModule.registerRewriteFunction( 99 "transform.dialect", [](PatternRewriter &, Operation *) {}); 100 101 it = compiledPatterns 102 .try_emplace(patternOp.getName(), std::move(patternModule)) 103 .first; 104 } 105 106 PatternApplicator applicator(it->second); 107 TrivialPatternRewriter rewriter(root->getContext()); 108 applicator.applyDefaultCostModel(); 109 root->walk([&](Operation *op) { 110 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 111 results.push_back(op); 112 }); 113 114 return success(); 115 } 116 } // namespace 117 118 //===----------------------------------------------------------------------===// 119 // GetClosestIsolatedParentOp 120 //===----------------------------------------------------------------------===// 121 122 LogicalResult transform::GetClosestIsolatedParentOp::apply( 123 transform::TransformResults &results, transform::TransformState &state) { 124 SetVector<Operation *> parents; 125 for (Operation *target : state.getPayloadOps(getTarget())) { 126 Operation *parent = 127 target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 128 if (!parent) { 129 InFlightDiagnostic diag = 130 emitError() << "could not find an isolated-from-above parent op"; 131 diag.attachNote(target->getLoc()) << "target op"; 132 return diag; 133 } 134 parents.insert(parent); 135 } 136 results.set(getResult().cast<OpResult>(), parents.getArrayRef()); 137 return success(); 138 } 139 140 void transform::GetClosestIsolatedParentOp::getEffects( 141 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 142 effects.emplace_back(MemoryEffects::Read::get(), getTarget(), 143 TransformMappingResource::get()); 144 effects.emplace_back(MemoryEffects::Allocate::get(), getParent(), 145 TransformMappingResource::get()); 146 effects.emplace_back(MemoryEffects::Write::get(), getParent(), 147 TransformMappingResource::get()); 148 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 149 } 150 151 //===----------------------------------------------------------------------===// 152 // PDLMatchOp 153 //===----------------------------------------------------------------------===// 154 155 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, 156 transform::TransformState &state) { 157 auto *extension = state.getExtension<PatternApplicatorExtension>(); 158 assert(extension && 159 "expected PatternApplicatorExtension to be attached by the parent op"); 160 SmallVector<Operation *> targets; 161 for (Operation *root : state.getPayloadOps(getRoot())) { 162 if (failed(extension->findAllMatches( 163 getPatternName().getLeafReference().getValue(), root, targets))) { 164 return emitOpError() << "could not find pattern '" << getPatternName() 165 << "'"; 166 } 167 } 168 results.set(getResult().cast<OpResult>(), targets); 169 return success(); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // SequenceOp 174 //===----------------------------------------------------------------------===// 175 176 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, 177 transform::TransformState &state) { 178 // Map the entry block argument to the list of operations. 179 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 180 if (failed(mapBlockArguments(state))) 181 return failure(); 182 183 // Apply the sequenced ops one by one. 184 for (Operation &transform : getBodyBlock()->without_terminator()) 185 if (failed(state.applyTransform(cast<TransformOpInterface>(transform)))) 186 return failure(); 187 188 // Forward the operation mapping for values yielded from the sequence to the 189 // values produced by the sequence op. 190 for (const auto &pair : 191 llvm::zip(getBodyBlock()->getTerminator()->getOperands(), 192 getOperation()->getOpResults())) { 193 Value terminatorOperand = std::get<0>(pair); 194 OpResult result = std::get<1>(pair); 195 results.set(result, state.getPayloadOps(terminatorOperand)); 196 } 197 198 return success(); 199 } 200 201 /// Returns `true` if the given op operand may be consuming the handle value in 202 /// the Transform IR. That is, if it may have a Free effect on it. 203 static bool isValueUsePotentialConsumer(OpOperand &use) { 204 // Conservatively assume the effect being present in absence of the interface. 205 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 206 if (!memEffectInterface) 207 return true; 208 209 SmallVector<MemoryEffects::EffectInstance, 2> effects; 210 memEffectInterface.getEffectsOnValue(use.get(), effects); 211 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 212 return isa<transform::TransformMappingResource>(effect.getResource()) && 213 isa<MemoryEffects::Free>(effect.getEffect()); 214 }); 215 } 216 217 LogicalResult 218 checkDoubleConsume(Value value, 219 function_ref<InFlightDiagnostic()> reportError) { 220 OpOperand *potentialConsumer = nullptr; 221 for (OpOperand &use : value.getUses()) { 222 if (!isValueUsePotentialConsumer(use)) 223 continue; 224 225 if (!potentialConsumer) { 226 potentialConsumer = &use; 227 continue; 228 } 229 230 InFlightDiagnostic diag = reportError() 231 << " has more than one potential consumer"; 232 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 233 << "used here as operand #" << potentialConsumer->getOperandNumber(); 234 diag.attachNote(use.getOwner()->getLoc()) 235 << "used here as operand #" << use.getOperandNumber(); 236 return diag; 237 } 238 239 return success(); 240 } 241 242 LogicalResult transform::SequenceOp::verify() { 243 // Check if the block argument has more than one consuming use. 244 for (BlockArgument argument : getBodyBlock()->getArguments()) { 245 auto report = [&]() { 246 return (emitOpError() << "block argument #" << argument.getArgNumber()); 247 }; 248 if (failed(checkDoubleConsume(argument, report))) 249 return failure(); 250 } 251 252 // Check properties of the nested operations they cannot check themselves. 253 for (Operation &child : *getBodyBlock()) { 254 if (!isa<TransformOpInterface>(child) && 255 &child != &getBodyBlock()->back()) { 256 InFlightDiagnostic diag = 257 emitOpError() 258 << "expected children ops to implement TransformOpInterface"; 259 diag.attachNote(child.getLoc()) << "op without interface"; 260 return diag; 261 } 262 263 for (OpResult result : child.getResults()) { 264 auto report = [&]() { 265 return (child.emitError() << "result #" << result.getResultNumber()); 266 }; 267 if (failed(checkDoubleConsume(result, report))) 268 return failure(); 269 } 270 } 271 272 if (getBodyBlock()->getTerminator()->getOperandTypes() != 273 getOperation()->getResultTypes()) { 274 InFlightDiagnostic diag = emitOpError() 275 << "expects the types of the terminator operands " 276 "to match the types of the result"; 277 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 278 return diag; 279 } 280 return success(); 281 } 282 283 void transform::SequenceOp::getEffects( 284 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 285 auto *mappingResource = TransformMappingResource::get(); 286 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 287 288 for (Value result : getResults()) { 289 effects.emplace_back(MemoryEffects::Allocate::get(), result, 290 mappingResource); 291 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 292 } 293 294 if (!getRoot()) { 295 for (Operation &op : *getBodyBlock()) { 296 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 297 if (!iface) { 298 // TODO: fill all possible effects; or require ops to actually implement 299 // the memory effect interface always 300 assert(false); 301 } 302 303 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 304 iface.getEffects(effects); 305 } 306 return; 307 } 308 309 // Carry over all effects on the argument of the entry block as those on the 310 // operand, this is the same value just remapped. 311 for (Operation &op : *getBodyBlock()) { 312 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 313 if (!iface) { 314 // TODO: fill all possible effects; or require ops to actually implement 315 // the memory effect interface always 316 assert(false); 317 } 318 319 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 320 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 321 for (const auto &effect : nestedEffects) 322 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 323 } 324 } 325 326 OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) { 327 assert(index == 0 && "unexpected region index"); 328 if (getOperation()->getNumOperands() == 1) 329 return getOperation()->getOperands(); 330 return OperandRange(getOperation()->operand_end(), 331 getOperation()->operand_end()); 332 } 333 334 void transform::SequenceOp::getSuccessorRegions( 335 Optional<unsigned> index, ArrayRef<Attribute> operands, 336 SmallVectorImpl<RegionSuccessor> ®ions) { 337 if (!index.hasValue()) { 338 Region *bodyRegion = &getBody(); 339 regions.emplace_back(bodyRegion, !operands.empty() 340 ? bodyRegion->getArguments() 341 : Block::BlockArgListType()); 342 return; 343 } 344 345 assert(*index == 0 && "unexpected region index"); 346 regions.emplace_back(getOperation()->getResults()); 347 } 348 349 void transform::SequenceOp::getRegionInvocationBounds( 350 ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) { 351 (void)operands; 352 bounds.emplace_back(1, 1); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // WithPDLPatternsOp 357 //===----------------------------------------------------------------------===// 358 359 LogicalResult 360 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 361 transform::TransformState &state) { 362 OwningOpRef<ModuleOp> pdlModuleOp = 363 ModuleOp::create(getOperation()->getLoc()); 364 TransformOpInterface transformOp = nullptr; 365 for (Operation &nested : getBody().front()) { 366 if (!isa<pdl::PatternOp>(nested)) { 367 transformOp = cast<TransformOpInterface>(nested); 368 break; 369 } 370 } 371 372 state.addExtension<PatternApplicatorExtension>(getOperation()); 373 auto guard = llvm::make_scope_exit( 374 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 375 376 auto scope = state.make_region_scope(getBody()); 377 if (failed(mapBlockArguments(state))) 378 return failure(); 379 return state.applyTransform(transformOp); 380 } 381 382 LogicalResult transform::WithPDLPatternsOp::verify() { 383 Block *body = getBodyBlock(); 384 Operation *topLevelOp = nullptr; 385 for (Operation &op : body->getOperations()) { 386 if (isa<pdl::PatternOp>(op)) 387 continue; 388 389 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 390 if (topLevelOp) { 391 InFlightDiagnostic diag = 392 emitOpError() << "expects only one non-pattern op in its body"; 393 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 394 diag.attachNote(op.getLoc()) << "second non-pattern op"; 395 return diag; 396 } 397 topLevelOp = &op; 398 continue; 399 } 400 401 InFlightDiagnostic diag = 402 emitOpError() 403 << "expects only pattern and top-level transform ops in its body"; 404 diag.attachNote(op.getLoc()) << "offending op"; 405 return diag; 406 } 407 408 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 409 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 410 diag.attachNote(parent.getLoc()) << "parent operation"; 411 return diag; 412 } 413 414 return success(); 415 } 416