1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 
20 using namespace mlir;
21 
22 #define GET_OP_CLASSES
23 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // PatternApplicatorExtension
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 /// A simple pattern rewriter that can be constructed from a context. This is
31 /// necessary to apply patterns to a specific op locally.
32 class TrivialPatternRewriter : public PatternRewriter {
33 public:
34   explicit TrivialPatternRewriter(MLIRContext *context)
35       : PatternRewriter(context) {}
36 };
37 
38 /// A TransformState extension that keeps track of compiled PDL pattern sets.
39 /// This is intended to be used along the WithPDLPatterns op. The extension
40 /// can be constructed given an operation that has a SymbolTable trait and
41 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
42 /// by one when requested; this behavior is subject to change.
43 class PatternApplicatorExtension : public transform::TransformState::Extension {
44 public:
45   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
46 
47   /// Creates the extension for patterns contained in `patternContainer`.
48   explicit PatternApplicatorExtension(transform::TransformState &state,
49                                       Operation *patternContainer)
50       : Extension(state), patterns(patternContainer) {}
51 
52   /// Appends to `results` the operations contained in `root` that matched the
53   /// PDL pattern with the given name. Note that `root` may or may not be the
54   /// operation that contains PDL patterns. Reports an error if the pattern
55   /// cannot be found. Note that when no operations are matched, this still
56   /// succeeds as long as the pattern exists.
57   LogicalResult findAllMatches(StringRef patternName, Operation *root,
58                                SmallVectorImpl<Operation *> &results);
59 
60 private:
61   /// Map from the pattern name to a singleton set of rewrite patterns that only
62   /// contains the pattern with this name. Populated when the pattern is first
63   /// requested.
64   // TODO: reconsider the efficiency of this storage when more usage data is
65   // available. Storing individual patterns in a set and triggering compilation
66   // for each of them has overhead. So does compiling a large set of patterns
67   // only to apply a handlful of them.
68   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
69 
70   /// A symbol table operation containing the relevant PDL patterns.
71   SymbolTable patterns;
72 };
73 
74 LogicalResult PatternApplicatorExtension::findAllMatches(
75     StringRef patternName, Operation *root,
76     SmallVectorImpl<Operation *> &results) {
77   auto it = compiledPatterns.find(patternName);
78   if (it == compiledPatterns.end()) {
79     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
80     if (!patternOp)
81       return failure();
82 
83     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
84     patternOp->moveBefore(pdlModuleOp->getBody(),
85                           pdlModuleOp->getBody()->end());
86     PDLPatternModule patternModule(std::move(pdlModuleOp));
87 
88     // Merge in the hooks owned by the dialect. Make a copy as they may be
89     // also used by the following operations.
90     auto *dialect =
91         root->getContext()->getLoadedDialect<transform::TransformDialect>();
92     for (const auto &pair : dialect->getPDLConstraintHooks())
93       patternModule.registerConstraintFunction(pair.first(), pair.second);
94 
95     // Register a noop rewriter because PDL requires patterns to end with some
96     // rewrite call.
97     patternModule.registerRewriteFunction(
98         "transform.dialect", [](PatternRewriter &, Operation *) {});
99 
100     it = compiledPatterns
101              .try_emplace(patternOp.getName(), std::move(patternModule))
102              .first;
103   }
104 
105   PatternApplicator applicator(it->second);
106   TrivialPatternRewriter rewriter(root->getContext());
107   applicator.applyDefaultCostModel();
108   root->walk([&](Operation *op) {
109     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
110       results.push_back(op);
111   });
112 
113   return success();
114 }
115 } // namespace
116 
117 //===----------------------------------------------------------------------===//
118 // PDLMatchOp
119 //===----------------------------------------------------------------------===//
120 
121 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
122                                            transform::TransformState &state) {
123   auto *extension = state.getExtension<PatternApplicatorExtension>();
124   assert(extension &&
125          "expected PatternApplicatorExtension to be attached by the parent op");
126   SmallVector<Operation *> targets;
127   for (Operation *root : state.getPayloadOps(getRoot())) {
128     if (failed(extension->findAllMatches(
129             getPatternName().getLeafReference().getValue(), root, targets))) {
130       return emitOpError() << "could not find pattern '" << getPatternName()
131                            << "'";
132     }
133   }
134   results.set(getResult().cast<OpResult>(), targets);
135   return success();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // SequenceOp
140 //===----------------------------------------------------------------------===//
141 
142 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
143                                            transform::TransformState &state) {
144   // Map the entry block argument to the list of operations.
145   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
146   if (failed(mapBlockArguments(state)))
147     return failure();
148 
149   // Apply the sequenced ops one by one.
150   for (Operation &transform : getBodyBlock()->without_terminator())
151     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
152       return failure();
153 
154   // Forward the operation mapping for values yielded from the sequence to the
155   // values produced by the sequence op.
156   for (const auto &pair :
157        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
158                  getOperation()->getOpResults())) {
159     Value terminatorOperand = std::get<0>(pair);
160     OpResult result = std::get<1>(pair);
161     results.set(result, state.getPayloadOps(terminatorOperand));
162   }
163 
164   return success();
165 }
166 
167 LogicalResult transform::SequenceOp::verify() {
168   for (Operation &child : *getBodyBlock()) {
169     if (!isa<TransformOpInterface>(child) &&
170         &child != &getBodyBlock()->back()) {
171       InFlightDiagnostic diag =
172           emitOpError()
173           << "expected children ops to implement TransformOpInterface";
174       diag.attachNote(child.getLoc()) << "op without interface";
175       return diag;
176     }
177 
178     for (OpResult result : child.getResults()) {
179       if (llvm::hasNItemsOrLess(result.getUses(), 1))
180         continue;
181       InFlightDiagnostic diag = child.emitError()
182                                 << "result #" << result.getResultNumber()
183                                 << " has more than one use";
184       for (OpOperand &use : result.getUses()) {
185         diag.attachNote(use.getOwner()->getLoc())
186             << "used here as operand #" << use.getOperandNumber();
187       }
188       return diag;
189     }
190   }
191 
192   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
193       getOperation()->getResultTypes()) {
194     InFlightDiagnostic diag = emitOpError()
195                               << "expects the types of the terminator operands "
196                                  "to match the types of the result";
197     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
198     return diag;
199   }
200   return success();
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // WithPDLPatternsOp
205 //===----------------------------------------------------------------------===//
206 
207 LogicalResult
208 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
209                                     transform::TransformState &state) {
210   OwningOpRef<ModuleOp> pdlModuleOp =
211       ModuleOp::create(getOperation()->getLoc());
212   TransformOpInterface transformOp = nullptr;
213   for (Operation &nested : getBody().front()) {
214     if (!isa<pdl::PatternOp>(nested)) {
215       transformOp = cast<TransformOpInterface>(nested);
216       break;
217     }
218   }
219 
220   state.addExtension<PatternApplicatorExtension>(getOperation());
221   auto guard = llvm::make_scope_exit(
222       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
223 
224   auto scope = state.make_region_scope(getBody());
225   if (failed(mapBlockArguments(state)))
226     return failure();
227   return state.applyTransform(transformOp);
228 }
229 
230 LogicalResult transform::WithPDLPatternsOp::verify() {
231   Block *body = getBodyBlock();
232   Operation *topLevelOp = nullptr;
233   for (Operation &op : body->getOperations()) {
234     if (isa<pdl::PatternOp>(op))
235       continue;
236 
237     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
238       if (topLevelOp) {
239         InFlightDiagnostic diag =
240             emitOpError() << "expects only one non-pattern op in its body";
241         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
242         diag.attachNote(op.getLoc()) << "second non-pattern op";
243         return diag;
244       }
245       topLevelOp = &op;
246       continue;
247     }
248 
249     InFlightDiagnostic diag =
250         emitOpError()
251         << "expects only pattern and top-level transform ops in its body";
252     diag.attachNote(op.getLoc()) << "offending op";
253     return diag;
254   }
255 
256   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
257     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
258     diag.attachNote(parent.getLoc()) << "parent operation";
259     return diag;
260   }
261 
262   return success();
263 }
264