1 //===- TransformDialect.cpp - Transform dialect operations ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/IR/TransformOps.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "mlir/IR/PatternMatch.h" 16 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 17 #include "mlir/Rewrite/PatternApplicator.h" 18 #include "llvm/ADT/ScopeExit.h" 19 20 using namespace mlir; 21 22 #define GET_OP_CLASSES 23 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" 24 25 //===----------------------------------------------------------------------===// 26 // PatternApplicatorExtension 27 //===----------------------------------------------------------------------===// 28 29 namespace { 30 /// A simple pattern rewriter that can be constructed from a context. This is 31 /// necessary to apply patterns to a specific op locally. 32 class TrivialPatternRewriter : public PatternRewriter { 33 public: 34 explicit TrivialPatternRewriter(MLIRContext *context) 35 : PatternRewriter(context) {} 36 }; 37 38 /// A TransformState extension that keeps track of compiled PDL pattern sets. 39 /// This is intended to be used along the WithPDLPatterns op. The extension 40 /// can be constructed given an operation that has a SymbolTable trait and 41 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one 42 /// by one when requested; this behavior is subject to change. 43 class PatternApplicatorExtension : public transform::TransformState::Extension { 44 public: 45 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) 46 47 /// Creates the extension for patterns contained in `patternContainer`. 48 explicit PatternApplicatorExtension(transform::TransformState &state, 49 Operation *patternContainer) 50 : Extension(state), patterns(patternContainer) {} 51 52 /// Appends to `results` the operations contained in `root` that matched the 53 /// PDL pattern with the given name. Note that `root` may or may not be the 54 /// operation that contains PDL patterns. Reports an error if the pattern 55 /// cannot be found. Note that when no operations are matched, this still 56 /// succeeds as long as the pattern exists. 57 LogicalResult findAllMatches(StringRef patternName, Operation *root, 58 SmallVectorImpl<Operation *> &results); 59 60 private: 61 /// Map from the pattern name to a singleton set of rewrite patterns that only 62 /// contains the pattern with this name. Populated when the pattern is first 63 /// requested. 64 // TODO: reconsider the efficiency of this storage when more usage data is 65 // available. Storing individual patterns in a set and triggering compilation 66 // for each of them has overhead. So does compiling a large set of patterns 67 // only to apply a handlful of them. 68 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; 69 70 /// A symbol table operation containing the relevant PDL patterns. 71 SymbolTable patterns; 72 }; 73 74 LogicalResult PatternApplicatorExtension::findAllMatches( 75 StringRef patternName, Operation *root, 76 SmallVectorImpl<Operation *> &results) { 77 auto it = compiledPatterns.find(patternName); 78 if (it == compiledPatterns.end()) { 79 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); 80 if (!patternOp) 81 return failure(); 82 83 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); 84 patternOp->moveBefore(pdlModuleOp->getBody(), 85 pdlModuleOp->getBody()->end()); 86 PDLPatternModule patternModule(std::move(pdlModuleOp)); 87 88 // Merge in the hooks owned by the dialect. Make a copy as they may be 89 // also used by the following operations. 90 auto *dialect = 91 root->getContext()->getLoadedDialect<transform::TransformDialect>(); 92 for (const auto &pair : dialect->getPDLConstraintHooks()) 93 patternModule.registerConstraintFunction(pair.first(), pair.second); 94 95 // Register a noop rewriter because PDL requires patterns to end with some 96 // rewrite call. 97 patternModule.registerRewriteFunction( 98 "transform.dialect", [](PatternRewriter &, Operation *) {}); 99 100 it = compiledPatterns 101 .try_emplace(patternOp.getName(), std::move(patternModule)) 102 .first; 103 } 104 105 PatternApplicator applicator(it->second); 106 TrivialPatternRewriter rewriter(root->getContext()); 107 applicator.applyDefaultCostModel(); 108 root->walk([&](Operation *op) { 109 if (succeeded(applicator.matchAndRewrite(op, rewriter))) 110 results.push_back(op); 111 }); 112 113 return success(); 114 } 115 } // namespace 116 117 //===----------------------------------------------------------------------===// 118 // PDLMatchOp 119 //===----------------------------------------------------------------------===// 120 121 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, 122 transform::TransformState &state) { 123 auto *extension = state.getExtension<PatternApplicatorExtension>(); 124 assert(extension && 125 "expected PatternApplicatorExtension to be attached by the parent op"); 126 SmallVector<Operation *> targets; 127 for (Operation *root : state.getPayloadOps(getRoot())) { 128 if (failed(extension->findAllMatches( 129 getPatternName().getLeafReference().getValue(), root, targets))) { 130 return emitOpError() << "could not find pattern '" << getPatternName() 131 << "'"; 132 } 133 } 134 results.set(getResult().cast<OpResult>(), targets); 135 return success(); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // SequenceOp 140 //===----------------------------------------------------------------------===// 141 142 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, 143 transform::TransformState &state) { 144 // Map the entry block argument to the list of operations. 145 auto scope = state.make_region_scope(*getBodyBlock()->getParent()); 146 if (failed(mapBlockArguments(state))) 147 return failure(); 148 149 // Apply the sequenced ops one by one. 150 for (Operation &transform : getBodyBlock()->without_terminator()) 151 if (failed(state.applyTransform(cast<TransformOpInterface>(transform)))) 152 return failure(); 153 154 // Forward the operation mapping for values yielded from the sequence to the 155 // values produced by the sequence op. 156 for (const auto &pair : 157 llvm::zip(getBodyBlock()->getTerminator()->getOperands(), 158 getOperation()->getOpResults())) { 159 Value terminatorOperand = std::get<0>(pair); 160 OpResult result = std::get<1>(pair); 161 results.set(result, state.getPayloadOps(terminatorOperand)); 162 } 163 164 return success(); 165 } 166 167 LogicalResult transform::SequenceOp::verify() { 168 for (Operation &child : *getBodyBlock()) { 169 if (!isa<TransformOpInterface>(child) && 170 &child != &getBodyBlock()->back()) { 171 InFlightDiagnostic diag = 172 emitOpError() 173 << "expected children ops to implement TransformOpInterface"; 174 diag.attachNote(child.getLoc()) << "op without interface"; 175 return diag; 176 } 177 178 for (OpResult result : child.getResults()) { 179 if (llvm::hasNItemsOrLess(result.getUses(), 1)) 180 continue; 181 InFlightDiagnostic diag = child.emitError() 182 << "result #" << result.getResultNumber() 183 << " has more than one use"; 184 for (OpOperand &use : result.getUses()) { 185 diag.attachNote(use.getOwner()->getLoc()) 186 << "used here as operand #" << use.getOperandNumber(); 187 } 188 return diag; 189 } 190 } 191 192 if (getBodyBlock()->getTerminator()->getOperandTypes() != 193 getOperation()->getResultTypes()) { 194 InFlightDiagnostic diag = emitOpError() 195 << "expects the types of the terminator operands " 196 "to match the types of the result"; 197 diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator"; 198 return diag; 199 } 200 return success(); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // WithPDLPatternsOp 205 //===----------------------------------------------------------------------===// 206 207 LogicalResult 208 transform::WithPDLPatternsOp::apply(transform::TransformResults &results, 209 transform::TransformState &state) { 210 OwningOpRef<ModuleOp> pdlModuleOp = 211 ModuleOp::create(getOperation()->getLoc()); 212 TransformOpInterface transformOp = nullptr; 213 for (Operation &nested : getBody().front()) { 214 if (!isa<pdl::PatternOp>(nested)) { 215 transformOp = cast<TransformOpInterface>(nested); 216 break; 217 } 218 } 219 220 state.addExtension<PatternApplicatorExtension>(getOperation()); 221 auto guard = llvm::make_scope_exit( 222 [&]() { state.removeExtension<PatternApplicatorExtension>(); }); 223 224 auto scope = state.make_region_scope(getBody()); 225 if (failed(mapBlockArguments(state))) 226 return failure(); 227 return state.applyTransform(transformOp); 228 } 229 230 LogicalResult transform::WithPDLPatternsOp::verify() { 231 Block *body = getBodyBlock(); 232 Operation *topLevelOp = nullptr; 233 for (Operation &op : body->getOperations()) { 234 if (isa<pdl::PatternOp>(op)) 235 continue; 236 237 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { 238 if (topLevelOp) { 239 InFlightDiagnostic diag = 240 emitOpError() << "expects only one non-pattern op in its body"; 241 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; 242 diag.attachNote(op.getLoc()) << "second non-pattern op"; 243 return diag; 244 } 245 topLevelOp = &op; 246 continue; 247 } 248 249 InFlightDiagnostic diag = 250 emitOpError() 251 << "expects only pattern and top-level transform ops in its body"; 252 diag.attachNote(op.getLoc()) << "offending op"; 253 return diag; 254 } 255 256 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { 257 InFlightDiagnostic diag = emitOpError() << "cannot be nested"; 258 diag.attachNote(parent.getLoc()) << "parent operation"; 259 return diag; 260 } 261 262 return success(); 263 } 264