1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 20 using namespace mlir; 21 22 #define GET_OP_CLASSES 23 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 24 25 //===----------------------------------------------------------------------===// 26 // PatternApplicatorExtension 27 //===----------------------------------------------------------------------===// 28 29 namespace { 30 /// A simple pattern rewriter that can be constructed from a context. This is 31 /// necessary to apply patterns to a specific op locally. 32 class TrivialPatternRewriter : public PatternRewriter { 33 public: 34 explicit TrivialPatternRewriter(MLIRContext *context) 35 : PatternRewriter(context) {} 36 }; 37 38 /// A TransformState extension that keeps track of compiled PDL pattern sets. 39 /// This is intended to be used along the WithPDLPatterns op. The extension 40 /// can be constructed given an operation that has a SymbolTable trait and 41 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 42 /// by one when requested; this behavior is subject to change. 43 class PatternApplicatorExtension : public transform::TransformState::Extension { 44 public: 45 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 46 47 /// Creates the extension for patterns contained in `patternContainer`. 48 explicit PatternApplicatorExtension(transform::TransformState &state, 49 Operation *patternContainer) 50 : Extension(state), patterns(patternContainer) {} 51 52 /// Appends to `results` the operations contained in `root` that matched the 53 /// PDL pattern with the given name. Note that `root` may or may not be the 54 /// operation that contains PDL patterns. Reports an error if the pattern 55 /// cannot be found. Note that when no operations are matched, this still 56 /// succeeds as long as the pattern exists. 57 LogicalResult findAllMatches(StringRef patternName, Operation *root, 58 SmallVectorImpl<Operation *> &results); 59 60 private: 61 /// Map from the pattern name to a singleton set of rewrite patterns that only 62 /// contains the pattern with this name. Populated when the pattern is first 63 /// requested. 64 // TODO: reconsider the efficiency of this storage when more usage data is 65 // available. Storing individual patterns in a set and triggering compilation 66 // for each of them has overhead. So does compiling a large set of patterns 67 // only to apply a handlful of them. 68 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 69 70 /// A symbol table operation containing the relevant PDL patterns. 71 SymbolTable patterns; 72 }; 73 74 LogicalResult PatternApplicatorExtension::findAllMatches( 75 StringRef patternName, Operation *root, 76 SmallVectorImpl<Operation *> &results) { 77 auto it = compiledPatterns.find(patternName); 78 if (it == compiledPatterns.end()) { 79 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 80 if (!patternOp) 81 return failure(); 82 83 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 84 patternOp->moveBefore(pdlModuleOp->getBody(), 85 pdlModuleOp->getBody()->end()); 86 PDLPatternModule patternModule(std::move(pdlModuleOp)); 87 88 // Merge in the hooks owned by the dialect. Make a copy as they may be 89 // also used by the following operations. 90 auto *dialect = 91 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 92 for (const auto &pair : dialect->getPDLConstraintHooks()) 93 patternModule.registerConstraintFunction(pair.first(), pair.second); 94 95 // Register a noop rewriter because PDL requires patterns to end with some 96 // rewrite call. 97 patternModule.registerRewriteFunction( 98 "transform.dialect", [](PatternRewriter &, Operation *) {}); 99 100 it = compiledPatterns 101 .try_emplace(patternOp.getName(), std::move(patternModule)) 102 .first; 103 } 104 105 PatternApplicator applicator(it->second); 106 TrivialPatternRewriter rewriter(root->getContext()); 107 applicator.applyDefaultCostModel(); 108 root->walk([&](Operation *op) { 109 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 110 results.push_back(op); 111 }); 112 113 return success(); 114 } 115 } // namespace 116 117 //===----------------------------------------------------------------------===// 118 // PDLMatchOp 119 //===----------------------------------------------------------------------===// 120 121 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, 122 transform::TransformState &state) { 123 auto *extension = state.getExtension<PatternApplicatorExtension>(); 124 assert(extension && 125 "expected PatternApplicatorExtension to be attached by the parent op"); 126 SmallVector<Operation *> targets; 127 for (Operation *root : state.getPayloadOps(getRoot())) { 128 if (failed(extension->findAllMatches( 129 getPatternName().getLeafReference().getValue(), root, targets))) { 130 return emitOpError() << "could not find pattern '" << getPatternName() 131 << "'"; 132 } 133 } 134 results.set(getResult().cast<OpResult>(), targets); 135 return success(); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // SequenceOp 140 //===----------------------------------------------------------------------===// 141 142 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, 143 transform::TransformState &state) { 144 // Map the entry block argument to the list of operations. 145 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 146 if (failed(mapBlockArguments(state))) 147 return failure(); 148 149 // Apply the sequenced ops one by one. 150 for (Operation &transform : getBodyBlock()->without_terminator()) 151 if (failed(state.applyTransform(cast<TransformOpInterface>(transform)))) 152 return failure(); 153 154 // Forward the operation mapping for values yielded from the sequence to the 155 // values produced by the sequence op. 156 for (const auto &pair : 157 llvm::zip(getBodyBlock()->getTerminator()->getOperands(), 158 getOperation()->getOpResults())) { 159 Value terminatorOperand = std::get<0>(pair); 160 OpResult result = std::get<1>(pair); 161 results.set(result, state.getPayloadOps(terminatorOperand)); 162 } 163 164 return success(); 165 } 166 167 /// Returns `true` if the given op operand may be consuming the handle value in 168 /// the Transform IR. That is, if it may have a Free effect on it. 169 static bool isValueUsePotentialConsumer(OpOperand &use) { 170 // Conservatively assume the effect being present in absence of the interface. 171 auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner()); 172 if (!memEffectInterface) 173 return true; 174 175 SmallVector<MemoryEffects::EffectInstance, 2> effects; 176 memEffectInterface.getEffectsOnValue(use.get(), effects); 177 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 178 return isa<transform::TransformMappingResource>(effect.getResource()) && 179 isa<MemoryEffects::Free>(effect.getEffect()); 180 }); 181 } 182 183 LogicalResult 184 checkDoubleConsume(Value value, 185 function_ref<InFlightDiagnostic()> reportError) { 186 OpOperand *potentialConsumer = nullptr; 187 for (OpOperand &use : value.getUses()) { 188 if (!isValueUsePotentialConsumer(use)) 189 continue; 190 191 if (!potentialConsumer) { 192 potentialConsumer = &use; 193 continue; 194 } 195 196 InFlightDiagnostic diag = reportError() 197 << " has more than one potential consumer"; 198 diag.attachNote(potentialConsumer->getOwner()->getLoc()) 199 << "used here as operand #" << potentialConsumer->getOperandNumber(); 200 diag.attachNote(use.getOwner()->getLoc()) 201 << "used here as operand #" << use.getOperandNumber(); 202 return diag; 203 } 204 205 return success(); 206 } 207 208 LogicalResult transform::SequenceOp::verify() { 209 // Check if the block argument has more than one consuming use. 210 for (BlockArgument argument : getBodyBlock()->getArguments()) { 211 auto report = [&]() { 212 return (emitOpError() << "block argument #" << argument.getArgNumber()); 213 }; 214 if (failed(checkDoubleConsume(argument, report))) 215 return failure(); 216 } 217 218 // Check properties of the nested operations they cannot check themselves. 219 for (Operation &child : *getBodyBlock()) { 220 if (!isa<TransformOpInterface>(child) && 221 &child != &getBodyBlock()->back()) { 222 InFlightDiagnostic diag = 223 emitOpError() 224 << "expected children ops to implement TransformOpInterface"; 225 diag.attachNote(child.getLoc()) << "op without interface"; 226 return diag; 227 } 228 229 for (OpResult result : child.getResults()) { 230 auto report = [&]() { 231 return (child.emitError() << "result #" << result.getResultNumber()); 232 }; 233 if (failed(checkDoubleConsume(result, report))) 234 return failure(); 235 } 236 } 237 238 if (getBodyBlock()->getTerminator()->getOperandTypes() != 239 getOperation()->getResultTypes()) { 240 InFlightDiagnostic diag = emitOpError() 241 << "expects the types of the terminator operands " 242 "to match the types of the result"; 243 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 244 return diag; 245 } 246 return success(); 247 } 248 249 void transform::SequenceOp::getEffects( 250 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 251 auto *mappingResource = TransformMappingResource::get(); 252 effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource); 253 254 for (Value result : getResults()) { 255 effects.emplace_back(MemoryEffects::Allocate::get(), result, 256 mappingResource); 257 effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource); 258 } 259 260 if (!getRoot()) { 261 for (Operation &op : *getBodyBlock()) { 262 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 263 if (!iface) { 264 // TODO: fill all possible effects; or require ops to actually implement 265 // the memory effect interface always 266 assert(false); 267 } 268 269 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 270 iface.getEffects(effects); 271 } 272 return; 273 } 274 275 // Carry over all effects on the argument of the entry block as those on the 276 // operand, this is the same value just remapped. 277 for (Operation &op : *getBodyBlock()) { 278 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 279 if (!iface) { 280 // TODO: fill all possible effects; or require ops to actually implement 281 // the memory effect interface always 282 assert(false); 283 } 284 285 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 286 iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects); 287 for (const auto &effect : nestedEffects) 288 effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource()); 289 } 290 } 291 292 //===----------------------------------------------------------------------===// 293 // WithPDLPatternsOp 294 //===----------------------------------------------------------------------===// 295 296 LogicalResult 297 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 298 transform::TransformState &state) { 299 OwningOpRef<ModuleOp> pdlModuleOp = 300 ModuleOp::create(getOperation()->getLoc()); 301 TransformOpInterface transformOp = nullptr; 302 for (Operation &nested : getBody().front()) { 303 if (!isa<pdl::PatternOp>(nested)) { 304 transformOp = cast<TransformOpInterface>(nested); 305 break; 306 } 307 } 308 309 state.addExtension<PatternApplicatorExtension>(getOperation()); 310 auto guard = llvm::make_scope_exit( 311 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 312 313 auto scope = state.make_region_scope(getBody()); 314 if (failed(mapBlockArguments(state))) 315 return failure(); 316 return state.applyTransform(transformOp); 317 } 318 319 LogicalResult transform::WithPDLPatternsOp::verify() { 320 Block *body = getBodyBlock(); 321 Operation *topLevelOp = nullptr; 322 for (Operation &op : body->getOperations()) { 323 if (isa<pdl::PatternOp>(op)) 324 continue; 325 326 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 327 if (topLevelOp) { 328 InFlightDiagnostic diag = 329 emitOpError() << "expects only one non-pattern op in its body"; 330 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 331 diag.attachNote(op.getLoc()) << "second non-pattern op"; 332 return diag; 333 } 334 topLevelOp = &op; 335 continue; 336 } 337 338 InFlightDiagnostic diag = 339 emitOpError() 340 << "expects only pattern and top-level transform ops in its body"; 341 diag.attachNote(op.getLoc()) << "offending op"; 342 return diag; 343 } 344 345 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 346 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 347 diag.attachNote(parent.getLoc()) << "parent operation"; 348 return diag; 349 } 350 351 return success(); 352 } 353