//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Interfaces/ControlFlowInterfaces.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Rewrite/PatternApplicator.h"
19 #include "llvm/ADT/ScopeExit.h"
20 
21 using namespace mlir;
22 
23 #define GET_OP_CLASSES
24 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // PatternApplicatorExtension
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 /// A simple pattern rewriter that can be constructed from a context. This is
32 /// necessary to apply patterns to a specific op locally.
33 class TrivialPatternRewriter : public PatternRewriter {
34 public:
35   explicit TrivialPatternRewriter(MLIRContext *context)
36       : PatternRewriter(context) {}
37 };
38 
39 /// A TransformState extension that keeps track of compiled PDL pattern sets.
40 /// This is intended to be used along the WithPDLPatterns op. The extension
41 /// can be constructed given an operation that has a SymbolTable trait and
42 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
43 /// by one when requested; this behavior is subject to change.
44 class PatternApplicatorExtension : public transform::TransformState::Extension {
45 public:
46   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
47 
48   /// Creates the extension for patterns contained in `patternContainer`.
49   explicit PatternApplicatorExtension(transform::TransformState &state,
50                                       Operation *patternContainer)
51       : Extension(state), patterns(patternContainer) {}
52 
53   /// Appends to `results` the operations contained in `root` that matched the
54   /// PDL pattern with the given name. Note that `root` may or may not be the
55   /// operation that contains PDL patterns. Reports an error if the pattern
56   /// cannot be found. Note that when no operations are matched, this still
57   /// succeeds as long as the pattern exists.
58   LogicalResult findAllMatches(StringRef patternName, Operation *root,
59                                SmallVectorImpl<Operation *> &results);
60 
61 private:
62   /// Map from the pattern name to a singleton set of rewrite patterns that only
63   /// contains the pattern with this name. Populated when the pattern is first
64   /// requested.
65   // TODO: reconsider the efficiency of this storage when more usage data is
66   // available. Storing individual patterns in a set and triggering compilation
67   // for each of them has overhead. So does compiling a large set of patterns
68   // only to apply a handlful of them.
69   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
70 
71   /// A symbol table operation containing the relevant PDL patterns.
72   SymbolTable patterns;
73 };
74 
75 LogicalResult PatternApplicatorExtension::findAllMatches(
76     StringRef patternName, Operation *root,
77     SmallVectorImpl<Operation *> &results) {
78   auto it = compiledPatterns.find(patternName);
79   if (it == compiledPatterns.end()) {
80     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
81     if (!patternOp)
82       return failure();
83 
84     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
85     patternOp->moveBefore(pdlModuleOp->getBody(),
86                           pdlModuleOp->getBody()->end());
87     PDLPatternModule patternModule(std::move(pdlModuleOp));
88 
89     // Merge in the hooks owned by the dialect. Make a copy as they may be
90     // also used by the following operations.
91     auto *dialect =
92         root->getContext()->getLoadedDialect<transform::TransformDialect>();
93     for (const auto &pair : dialect->getPDLConstraintHooks())
94       patternModule.registerConstraintFunction(pair.first(), pair.second);
95 
96     // Register a noop rewriter because PDL requires patterns to end with some
97     // rewrite call.
98     patternModule.registerRewriteFunction(
99         "transform.dialect", [](PatternRewriter &, Operation *) {});
100 
101     it = compiledPatterns
102              .try_emplace(patternOp.getName(), std::move(patternModule))
103              .first;
104   }
105 
106   PatternApplicator applicator(it->second);
107   TrivialPatternRewriter rewriter(root->getContext());
108   applicator.applyDefaultCostModel();
109   root->walk([&](Operation *op) {
110     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
111       results.push_back(op);
112   });
113 
114   return success();
115 }
116 } // namespace
117 
118 //===----------------------------------------------------------------------===//
119 // GetClosestIsolatedParentOp
120 //===----------------------------------------------------------------------===//
121 
122 LogicalResult transform::GetClosestIsolatedParentOp::apply(
123     transform::TransformResults &results, transform::TransformState &state) {
124   SetVector<Operation *> parents;
125   for (Operation *target : state.getPayloadOps(getTarget())) {
126     Operation *parent =
127         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
128     if (!parent) {
129       InFlightDiagnostic diag =
130           emitError() << "could not find an isolated-from-above parent op";
131       diag.attachNote(target->getLoc()) << "target op";
132       return diag;
133     }
134     parents.insert(parent);
135   }
136   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
137   return success();
138 }
139 
140 void transform::GetClosestIsolatedParentOp::getEffects(
141     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
142   effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
143                        TransformMappingResource::get());
144   effects.emplace_back(MemoryEffects::Allocate::get(), getParent(),
145                        TransformMappingResource::get());
146   effects.emplace_back(MemoryEffects::Write::get(), getParent(),
147                        TransformMappingResource::get());
148   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // PDLMatchOp
153 //===----------------------------------------------------------------------===//
154 
155 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
156                                            transform::TransformState &state) {
157   auto *extension = state.getExtension<PatternApplicatorExtension>();
158   assert(extension &&
159          "expected PatternApplicatorExtension to be attached by the parent op");
160   SmallVector<Operation *> targets;
161   for (Operation *root : state.getPayloadOps(getRoot())) {
162     if (failed(extension->findAllMatches(
163             getPatternName().getLeafReference().getValue(), root, targets))) {
164       return emitOpError() << "could not find pattern '" << getPatternName()
165                            << "'";
166     }
167   }
168   results.set(getResult().cast<OpResult>(), targets);
169   return success();
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // SequenceOp
174 //===----------------------------------------------------------------------===//
175 
176 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
177                                            transform::TransformState &state) {
178   // Map the entry block argument to the list of operations.
179   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
180   if (failed(mapBlockArguments(state)))
181     return failure();
182 
183   // Apply the sequenced ops one by one.
184   for (Operation &transform : getBodyBlock()->without_terminator())
185     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
186       return failure();
187 
188   // Forward the operation mapping for values yielded from the sequence to the
189   // values produced by the sequence op.
190   for (const auto &pair :
191        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
192                  getOperation()->getOpResults())) {
193     Value terminatorOperand = std::get<0>(pair);
194     OpResult result = std::get<1>(pair);
195     results.set(result, state.getPayloadOps(terminatorOperand));
196   }
197 
198   return success();
199 }
200 
201 /// Returns `true` if the given op operand may be consuming the handle value in
202 /// the Transform IR. That is, if it may have a Free effect on it.
203 static bool isValueUsePotentialConsumer(OpOperand &use) {
204   // Conservatively assume the effect being present in absence of the interface.
205   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
206   if (!memEffectInterface)
207     return true;
208 
209   SmallVector<MemoryEffects::EffectInstance, 2> effects;
210   memEffectInterface.getEffectsOnValue(use.get(), effects);
211   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
212     return isa<transform::TransformMappingResource>(effect.getResource()) &&
213            isa<MemoryEffects::Free>(effect.getEffect());
214   });
215 }
216 
217 LogicalResult
218 checkDoubleConsume(Value value,
219                    function_ref<InFlightDiagnostic()> reportError) {
220   OpOperand *potentialConsumer = nullptr;
221   for (OpOperand &use : value.getUses()) {
222     if (!isValueUsePotentialConsumer(use))
223       continue;
224 
225     if (!potentialConsumer) {
226       potentialConsumer = &use;
227       continue;
228     }
229 
230     InFlightDiagnostic diag = reportError()
231                               << " has more than one potential consumer";
232     diag.attachNote(potentialConsumer->getOwner()->getLoc())
233         << "used here as operand #" << potentialConsumer->getOperandNumber();
234     diag.attachNote(use.getOwner()->getLoc())
235         << "used here as operand #" << use.getOperandNumber();
236     return diag;
237   }
238 
239   return success();
240 }
241 
242 LogicalResult transform::SequenceOp::verify() {
243   // Check if the block argument has more than one consuming use.
244   for (BlockArgument argument : getBodyBlock()->getArguments()) {
245     auto report = [&]() {
246       return (emitOpError() << "block argument #" << argument.getArgNumber());
247     };
248     if (failed(checkDoubleConsume(argument, report)))
249       return failure();
250   }
251 
252   // Check properties of the nested operations they cannot check themselves.
253   for (Operation &child : *getBodyBlock()) {
254     if (!isa<TransformOpInterface>(child) &&
255         &child != &getBodyBlock()->back()) {
256       InFlightDiagnostic diag =
257           emitOpError()
258           << "expected children ops to implement TransformOpInterface";
259       diag.attachNote(child.getLoc()) << "op without interface";
260       return diag;
261     }
262 
263     for (OpResult result : child.getResults()) {
264       auto report = [&]() {
265         return (child.emitError() << "result #" << result.getResultNumber());
266       };
267       if (failed(checkDoubleConsume(result, report)))
268         return failure();
269     }
270   }
271 
272   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
273       getOperation()->getResultTypes()) {
274     InFlightDiagnostic diag = emitOpError()
275                               << "expects the types of the terminator operands "
276                                  "to match the types of the result";
277     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
278     return diag;
279   }
280   return success();
281 }
282 
283 void transform::SequenceOp::getEffects(
284     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
285   auto *mappingResource = TransformMappingResource::get();
286   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
287 
288   for (Value result : getResults()) {
289     effects.emplace_back(MemoryEffects::Allocate::get(), result,
290                          mappingResource);
291     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
292   }
293 
294   if (!getRoot()) {
295     for (Operation &op : *getBodyBlock()) {
296       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
297       if (!iface) {
298         // TODO: fill all possible effects; or require ops to actually implement
299         // the memory effect interface always
300         assert(false);
301       }
302 
303       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
304       iface.getEffects(effects);
305     }
306     return;
307   }
308 
309   // Carry over all effects on the argument of the entry block as those on the
310   // operand, this is the same value just remapped.
311   for (Operation &op : *getBodyBlock()) {
312     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
313     if (!iface) {
314       // TODO: fill all possible effects; or require ops to actually implement
315       // the memory effect interface always
316       assert(false);
317     }
318 
319     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
320     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
321     for (const auto &effect : nestedEffects)
322       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
323   }
324 }
325 
326 OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
327   assert(index == 0 && "unexpected region index");
328   if (getOperation()->getNumOperands() == 1)
329     return getOperation()->getOperands();
330   return OperandRange(getOperation()->operand_end(),
331                       getOperation()->operand_end());
332 }
333 
334 void transform::SequenceOp::getSuccessorRegions(
335     Optional<unsigned> index, ArrayRef<Attribute> operands,
336     SmallVectorImpl<RegionSuccessor> &regions) {
337   if (!index.hasValue()) {
338     Region *bodyRegion = &getBody();
339     regions.emplace_back(bodyRegion, !operands.empty()
340                                          ? bodyRegion->getArguments()
341                                          : Block::BlockArgListType());
342     return;
343   }
344 
345   assert(*index == 0 && "unexpected region index");
346   regions.emplace_back(getOperation()->getResults());
347 }
348 
349 void transform::SequenceOp::getRegionInvocationBounds(
350     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
351   (void)operands;
352   bounds.emplace_back(1, 1);
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // WithPDLPatternsOp
357 //===----------------------------------------------------------------------===//
358 
359 LogicalResult
360 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
361                                     transform::TransformState &state) {
362   OwningOpRef<ModuleOp> pdlModuleOp =
363       ModuleOp::create(getOperation()->getLoc());
364   TransformOpInterface transformOp = nullptr;
365   for (Operation &nested : getBody().front()) {
366     if (!isa<pdl::PatternOp>(nested)) {
367       transformOp = cast<TransformOpInterface>(nested);
368       break;
369     }
370   }
371 
372   state.addExtension<PatternApplicatorExtension>(getOperation());
373   auto guard = llvm::make_scope_exit(
374       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
375 
376   auto scope = state.make_region_scope(getBody());
377   if (failed(mapBlockArguments(state)))
378     return failure();
379   return state.applyTransform(transformOp);
380 }
381 
382 LogicalResult transform::WithPDLPatternsOp::verify() {
383   Block *body = getBodyBlock();
384   Operation *topLevelOp = nullptr;
385   for (Operation &op : body->getOperations()) {
386     if (isa<pdl::PatternOp>(op))
387       continue;
388 
389     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
390       if (topLevelOp) {
391         InFlightDiagnostic diag =
392             emitOpError() << "expects only one non-pattern op in its body";
393         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
394         diag.attachNote(op.getLoc()) << "second non-pattern op";
395         return diag;
396       }
397       topLevelOp = &op;
398       continue;
399     }
400 
401     InFlightDiagnostic diag =
402         emitOpError()
403         << "expects only pattern and top-level transform ops in its body";
404     diag.attachNote(op.getLoc()) << "offending op";
405     return diag;
406   }
407 
408   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
409     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
410     diag.attachNote(parent.getLoc()) << "parent operation";
411     return diag;
412   }
413 
414   return success();
415 }
416