//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 
20 using namespace mlir;
21 
22 #define GET_OP_CLASSES
23 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // PatternApplicatorExtension
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 /// A simple pattern rewriter that can be constructed from a context. This is
31 /// necessary to apply patterns to a specific op locally.
32 class TrivialPatternRewriter : public PatternRewriter {
33 public:
34   explicit TrivialPatternRewriter(MLIRContext *context)
35       : PatternRewriter(context) {}
36 };
37 
38 /// A TransformState extension that keeps track of compiled PDL pattern sets.
39 /// This is intended to be used along the WithPDLPatterns op. The extension
40 /// can be constructed given an operation that has a SymbolTable trait and
41 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
42 /// by one when requested; this behavior is subject to change.
43 class PatternApplicatorExtension : public transform::TransformState::Extension {
44 public:
45   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
46 
47   /// Creates the extension for patterns contained in `patternContainer`.
48   explicit PatternApplicatorExtension(transform::TransformState &state,
49                                       Operation *patternContainer)
50       : Extension(state), patterns(patternContainer) {}
51 
52   /// Appends to `results` the operations contained in `root` that matched the
53   /// PDL pattern with the given name. Note that `root` may or may not be the
54   /// operation that contains PDL patterns. Reports an error if the pattern
55   /// cannot be found. Note that when no operations are matched, this still
56   /// succeeds as long as the pattern exists.
57   LogicalResult findAllMatches(StringRef patternName, Operation *root,
58                                SmallVectorImpl<Operation *> &results);
59 
60 private:
61   /// Map from the pattern name to a singleton set of rewrite patterns that only
62   /// contains the pattern with this name. Populated when the pattern is first
63   /// requested.
64   // TODO: reconsider the efficiency of this storage when more usage data is
65   // available. Storing individual patterns in a set and triggering compilation
66   // for each of them has overhead. So does compiling a large set of patterns
67   // only to apply a handlful of them.
68   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
69 
70   /// A symbol table operation containing the relevant PDL patterns.
71   SymbolTable patterns;
72 };
73 
74 LogicalResult PatternApplicatorExtension::findAllMatches(
75     StringRef patternName, Operation *root,
76     SmallVectorImpl<Operation *> &results) {
77   auto it = compiledPatterns.find(patternName);
78   if (it == compiledPatterns.end()) {
79     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
80     if (!patternOp)
81       return failure();
82 
83     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
84     patternOp->moveBefore(pdlModuleOp->getBody(),
85                           pdlModuleOp->getBody()->end());
86     PDLPatternModule patternModule(std::move(pdlModuleOp));
87 
88     // Merge in the hooks owned by the dialect. Make a copy as they may be
89     // also used by the following operations.
90     auto *dialect =
91         root->getContext()->getLoadedDialect<transform::TransformDialect>();
92     for (const auto &pair : dialect->getPDLConstraintHooks())
93       patternModule.registerConstraintFunction(pair.first(), pair.second);
94 
95     // Register a noop rewriter because PDL requires patterns to end with some
96     // rewrite call.
97     patternModule.registerRewriteFunction(
98         "transform.dialect", [](PatternRewriter &, Operation *) {});
99 
100     it = compiledPatterns
101              .try_emplace(patternOp.getName(), std::move(patternModule))
102              .first;
103   }
104 
105   PatternApplicator applicator(it->second);
106   TrivialPatternRewriter rewriter(root->getContext());
107   applicator.applyDefaultCostModel();
108   root->walk([&](Operation *op) {
109     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
110       results.push_back(op);
111   });
112 
113   return success();
114 }
115 } // namespace
116 
117 //===----------------------------------------------------------------------===//
118 // PDLMatchOp
119 //===----------------------------------------------------------------------===//
120 
121 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
122                                            transform::TransformState &state) {
123   auto *extension = state.getExtension<PatternApplicatorExtension>();
124   assert(extension &&
125          "expected PatternApplicatorExtension to be attached by the parent op");
126   SmallVector<Operation *> targets;
127   for (Operation *root : state.getPayloadOps(getRoot())) {
128     if (failed(extension->findAllMatches(
129             getPatternName().getLeafReference().getValue(), root, targets))) {
130       return emitOpError() << "could not find pattern '" << getPatternName()
131                            << "'";
132     }
133   }
134   results.set(getResult().cast<OpResult>(), targets);
135   return success();
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // SequenceOp
140 //===----------------------------------------------------------------------===//
141 
142 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
143                                            transform::TransformState &state) {
144   // Map the entry block argument to the list of operations.
145   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
146   if (failed(mapBlockArguments(state)))
147     return failure();
148 
149   // Apply the sequenced ops one by one.
150   for (Operation &transform : getBodyBlock()->without_terminator())
151     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
152       return failure();
153 
154   // Forward the operation mapping for values yielded from the sequence to the
155   // values produced by the sequence op.
156   for (const auto &pair :
157        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
158                  getOperation()->getOpResults())) {
159     Value terminatorOperand = std::get<0>(pair);
160     OpResult result = std::get<1>(pair);
161     results.set(result, state.getPayloadOps(terminatorOperand));
162   }
163 
164   return success();
165 }
166 
167 /// Returns `true` if the given op operand may be consuming the handle value in
168 /// the Transform IR. That is, if it may have a Free effect on it.
169 static bool isValueUsePotentialConsumer(OpOperand &use) {
170   // Conservatively assume the effect being present in absence of the interface.
171   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
172   if (!memEffectInterface)
173     return true;
174 
175   SmallVector<MemoryEffects::EffectInstance, 2> effects;
176   memEffectInterface.getEffectsOnValue(use.get(), effects);
177   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
178     return isa<transform::TransformMappingResource>(effect.getResource()) &&
179            isa<MemoryEffects::Free>(effect.getEffect());
180   });
181 }
182 
183 LogicalResult
184 checkDoubleConsume(Value value,
185                    function_ref<InFlightDiagnostic()> reportError) {
186   OpOperand *potentialConsumer = nullptr;
187   for (OpOperand &use : value.getUses()) {
188     if (!isValueUsePotentialConsumer(use))
189       continue;
190 
191     if (!potentialConsumer) {
192       potentialConsumer = &use;
193       continue;
194     }
195 
196     InFlightDiagnostic diag = reportError()
197                               << " has more than one potential consumer";
198     diag.attachNote(potentialConsumer->getOwner()->getLoc())
199         << "used here as operand #" << potentialConsumer->getOperandNumber();
200     diag.attachNote(use.getOwner()->getLoc())
201         << "used here as operand #" << use.getOperandNumber();
202     return diag;
203   }
204 
205   return success();
206 }
207 
208 LogicalResult transform::SequenceOp::verify() {
209   // Check if the block argument has more than one consuming use.
210   for (BlockArgument argument : getBodyBlock()->getArguments()) {
211     auto report = [&]() {
212       return (emitOpError() << "block argument #" << argument.getArgNumber());
213     };
214     if (failed(checkDoubleConsume(argument, report)))
215       return failure();
216   }
217 
218   // Check properties of the nested operations they cannot check themselves.
219   for (Operation &child : *getBodyBlock()) {
220     if (!isa<TransformOpInterface>(child) &&
221         &child != &getBodyBlock()->back()) {
222       InFlightDiagnostic diag =
223           emitOpError()
224           << "expected children ops to implement TransformOpInterface";
225       diag.attachNote(child.getLoc()) << "op without interface";
226       return diag;
227     }
228 
229     for (OpResult result : child.getResults()) {
230       auto report = [&]() {
231         return (child.emitError() << "result #" << result.getResultNumber());
232       };
233       if (failed(checkDoubleConsume(result, report)))
234         return failure();
235     }
236   }
237 
238   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
239       getOperation()->getResultTypes()) {
240     InFlightDiagnostic diag = emitOpError()
241                               << "expects the types of the terminator operands "
242                                  "to match the types of the result";
243     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
244     return diag;
245   }
246   return success();
247 }
248 
249 void transform::SequenceOp::getEffects(
250     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
251   auto *mappingResource = TransformMappingResource::get();
252   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
253 
254   for (Value result : getResults()) {
255     effects.emplace_back(MemoryEffects::Allocate::get(), result,
256                          mappingResource);
257     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
258   }
259 
260   if (!getRoot()) {
261     for (Operation &op : *getBodyBlock()) {
262       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
263       if (!iface) {
264         // TODO: fill all possible effects; or require ops to actually implement
265         // the memory effect interface always
266         assert(false);
267       }
268 
269       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
270       iface.getEffects(effects);
271     }
272     return;
273   }
274 
275   // Carry over all effects on the argument of the entry block as those on the
276   // operand, this is the same value just remapped.
277   for (Operation &op : *getBodyBlock()) {
278     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
279     if (!iface) {
280       // TODO: fill all possible effects; or require ops to actually implement
281       // the memory effect interface always
282       assert(false);
283     }
284 
285     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
286     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
287     for (const auto &effect : nestedEffects)
288       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
289   }
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // WithPDLPatternsOp
294 //===----------------------------------------------------------------------===//
295 
296 LogicalResult
297 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
298                                     transform::TransformState &state) {
299   OwningOpRef<ModuleOp> pdlModuleOp =
300       ModuleOp::create(getOperation()->getLoc());
301   TransformOpInterface transformOp = nullptr;
302   for (Operation &nested : getBody().front()) {
303     if (!isa<pdl::PatternOp>(nested)) {
304       transformOp = cast<TransformOpInterface>(nested);
305       break;
306     }
307   }
308 
309   state.addExtension<PatternApplicatorExtension>(getOperation());
310   auto guard = llvm::make_scope_exit(
311       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
312 
313   auto scope = state.make_region_scope(getBody());
314   if (failed(mapBlockArguments(state)))
315     return failure();
316   return state.applyTransform(transformOp);
317 }
318 
319 LogicalResult transform::WithPDLPatternsOp::verify() {
320   Block *body = getBodyBlock();
321   Operation *topLevelOp = nullptr;
322   for (Operation &op : body->getOperations()) {
323     if (isa<pdl::PatternOp>(op))
324       continue;
325 
326     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
327       if (topLevelOp) {
328         InFlightDiagnostic diag =
329             emitOpError() << "expects only one non-pattern op in its body";
330         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
331         diag.attachNote(op.getLoc()) << "second non-pattern op";
332         return diag;
333       }
334       topLevelOp = &op;
335       continue;
336     }
337 
338     InFlightDiagnostic diag =
339         emitOpError()
340         << "expects only pattern and top-level transform ops in its body";
341     diag.attachNote(op.getLoc()) << "offending op";
342     return diag;
343   }
344 
345   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
346     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
347     diag.attachNote(parent.getLoc()) << "parent operation";
348     return diag;
349   }
350 
351   return success();
352 }
353