1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Interfaces/ControlFlowInterfaces.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Rewrite/PatternApplicator.h"
19 #include "llvm/ADT/ScopeExit.h"
20 
21 using namespace mlir;
22 
23 #define GET_OP_CLASSES
24 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // PatternApplicatorExtension
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 /// A simple pattern rewriter that can be constructed from a context. This is
32 /// necessary to apply patterns to a specific op locally.
33 class TrivialPatternRewriter : public PatternRewriter {
34 public:
35   explicit TrivialPatternRewriter(MLIRContext *context)
36       : PatternRewriter(context) {}
37 };
38 
39 /// A TransformState extension that keeps track of compiled PDL pattern sets.
40 /// This is intended to be used along the WithPDLPatterns op. The extension
41 /// can be constructed given an operation that has a SymbolTable trait and
42 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
43 /// by one when requested; this behavior is subject to change.
44 class PatternApplicatorExtension : public transform::TransformState::Extension {
45 public:
46   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
47 
48   /// Creates the extension for patterns contained in `patternContainer`.
49   explicit PatternApplicatorExtension(transform::TransformState &state,
50                                       Operation *patternContainer)
51       : Extension(state), patterns(patternContainer) {}
52 
53   /// Appends to `results` the operations contained in `root` that matched the
54   /// PDL pattern with the given name. Note that `root` may or may not be the
55   /// operation that contains PDL patterns. Reports an error if the pattern
56   /// cannot be found. Note that when no operations are matched, this still
57   /// succeeds as long as the pattern exists.
58   LogicalResult findAllMatches(StringRef patternName, Operation *root,
59                                SmallVectorImpl<Operation *> &results);
60 
61 private:
62   /// Map from the pattern name to a singleton set of rewrite patterns that only
63   /// contains the pattern with this name. Populated when the pattern is first
64   /// requested.
65   // TODO: reconsider the efficiency of this storage when more usage data is
66   // available. Storing individual patterns in a set and triggering compilation
67   // for each of them has overhead. So does compiling a large set of patterns
68   // only to apply a handlful of them.
69   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
70 
71   /// A symbol table operation containing the relevant PDL patterns.
72   SymbolTable patterns;
73 };
74 
75 LogicalResult PatternApplicatorExtension::findAllMatches(
76     StringRef patternName, Operation *root,
77     SmallVectorImpl<Operation *> &results) {
78   auto it = compiledPatterns.find(patternName);
79   if (it == compiledPatterns.end()) {
80     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
81     if (!patternOp)
82       return failure();
83 
84     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
85     patternOp->moveBefore(pdlModuleOp->getBody(),
86                           pdlModuleOp->getBody()->end());
87     PDLPatternModule patternModule(std::move(pdlModuleOp));
88 
89     // Merge in the hooks owned by the dialect. Make a copy as they may be
90     // also used by the following operations.
91     auto *dialect =
92         root->getContext()->getLoadedDialect<transform::TransformDialect>();
93     for (const auto &pair : dialect->getPDLConstraintHooks())
94       patternModule.registerConstraintFunction(pair.first(), pair.second);
95 
96     // Register a noop rewriter because PDL requires patterns to end with some
97     // rewrite call.
98     patternModule.registerRewriteFunction(
99         "transform.dialect", [](PatternRewriter &, Operation *) {});
100 
101     it = compiledPatterns
102              .try_emplace(patternOp.getName(), std::move(patternModule))
103              .first;
104   }
105 
106   PatternApplicator applicator(it->second);
107   TrivialPatternRewriter rewriter(root->getContext());
108   applicator.applyDefaultCostModel();
109   root->walk([&](Operation *op) {
110     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
111       results.push_back(op);
112   });
113 
114   return success();
115 }
116 } // namespace
117 
118 //===----------------------------------------------------------------------===//
119 // GetClosestIsolatedParentOp
120 //===----------------------------------------------------------------------===//
121 
122 LogicalResult transform::GetClosestIsolatedParentOp::apply(
123     transform::TransformResults &results, transform::TransformState &state) {
124   SetVector<Operation *> parents;
125   for (Operation *target : state.getPayloadOps(getTarget())) {
126     Operation *parent =
127         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
128     if (!parent) {
129       InFlightDiagnostic diag =
130           emitError() << "could not find an isolated-from-above parent op";
131       diag.attachNote(target->getLoc()) << "target op";
132       return diag;
133     }
134     parents.insert(parent);
135   }
136   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
137   return success();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // PDLMatchOp
142 //===----------------------------------------------------------------------===//
143 
144 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
145                                            transform::TransformState &state) {
146   auto *extension = state.getExtension<PatternApplicatorExtension>();
147   assert(extension &&
148          "expected PatternApplicatorExtension to be attached by the parent op");
149   SmallVector<Operation *> targets;
150   for (Operation *root : state.getPayloadOps(getRoot())) {
151     if (failed(extension->findAllMatches(
152             getPatternName().getLeafReference().getValue(), root, targets))) {
153       return emitOpError() << "could not find pattern '" << getPatternName()
154                            << "'";
155     }
156   }
157   results.set(getResult().cast<OpResult>(), targets);
158   return success();
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // SequenceOp
163 //===----------------------------------------------------------------------===//
164 
165 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
166                                            transform::TransformState &state) {
167   // Map the entry block argument to the list of operations.
168   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
169   if (failed(mapBlockArguments(state)))
170     return failure();
171 
172   // Apply the sequenced ops one by one.
173   for (Operation &transform : getBodyBlock()->without_terminator())
174     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
175       return failure();
176 
177   // Forward the operation mapping for values yielded from the sequence to the
178   // values produced by the sequence op.
179   for (const auto &pair :
180        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
181                  getOperation()->getOpResults())) {
182     Value terminatorOperand = std::get<0>(pair);
183     OpResult result = std::get<1>(pair);
184     results.set(result, state.getPayloadOps(terminatorOperand));
185   }
186 
187   return success();
188 }
189 
190 /// Returns `true` if the given op operand may be consuming the handle value in
191 /// the Transform IR. That is, if it may have a Free effect on it.
192 static bool isValueUsePotentialConsumer(OpOperand &use) {
193   // Conservatively assume the effect being present in absence of the interface.
194   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
195   if (!memEffectInterface)
196     return true;
197 
198   SmallVector<MemoryEffects::EffectInstance, 2> effects;
199   memEffectInterface.getEffectsOnValue(use.get(), effects);
200   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
201     return isa<transform::TransformMappingResource>(effect.getResource()) &&
202            isa<MemoryEffects::Free>(effect.getEffect());
203   });
204 }
205 
206 LogicalResult
207 checkDoubleConsume(Value value,
208                    function_ref<InFlightDiagnostic()> reportError) {
209   OpOperand *potentialConsumer = nullptr;
210   for (OpOperand &use : value.getUses()) {
211     if (!isValueUsePotentialConsumer(use))
212       continue;
213 
214     if (!potentialConsumer) {
215       potentialConsumer = &use;
216       continue;
217     }
218 
219     InFlightDiagnostic diag = reportError()
220                               << " has more than one potential consumer";
221     diag.attachNote(potentialConsumer->getOwner()->getLoc())
222         << "used here as operand #" << potentialConsumer->getOperandNumber();
223     diag.attachNote(use.getOwner()->getLoc())
224         << "used here as operand #" << use.getOperandNumber();
225     return diag;
226   }
227 
228   return success();
229 }
230 
231 LogicalResult transform::SequenceOp::verify() {
232   // Check if the block argument has more than one consuming use.
233   for (BlockArgument argument : getBodyBlock()->getArguments()) {
234     auto report = [&]() {
235       return (emitOpError() << "block argument #" << argument.getArgNumber());
236     };
237     if (failed(checkDoubleConsume(argument, report)))
238       return failure();
239   }
240 
241   // Check properties of the nested operations they cannot check themselves.
242   for (Operation &child : *getBodyBlock()) {
243     if (!isa<TransformOpInterface>(child) &&
244         &child != &getBodyBlock()->back()) {
245       InFlightDiagnostic diag =
246           emitOpError()
247           << "expected children ops to implement TransformOpInterface";
248       diag.attachNote(child.getLoc()) << "op without interface";
249       return diag;
250     }
251 
252     for (OpResult result : child.getResults()) {
253       auto report = [&]() {
254         return (child.emitError() << "result #" << result.getResultNumber());
255       };
256       if (failed(checkDoubleConsume(result, report)))
257         return failure();
258     }
259   }
260 
261   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
262       getOperation()->getResultTypes()) {
263     InFlightDiagnostic diag = emitOpError()
264                               << "expects the types of the terminator operands "
265                                  "to match the types of the result";
266     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
267     return diag;
268   }
269   return success();
270 }
271 
272 void transform::SequenceOp::getEffects(
273     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
274   auto *mappingResource = TransformMappingResource::get();
275   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
276 
277   for (Value result : getResults()) {
278     effects.emplace_back(MemoryEffects::Allocate::get(), result,
279                          mappingResource);
280     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
281   }
282 
283   if (!getRoot()) {
284     for (Operation &op : *getBodyBlock()) {
285       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
286       if (!iface) {
287         // TODO: fill all possible effects; or require ops to actually implement
288         // the memory effect interface always
289         assert(false);
290       }
291 
292       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
293       iface.getEffects(effects);
294     }
295     return;
296   }
297 
298   // Carry over all effects on the argument of the entry block as those on the
299   // operand, this is the same value just remapped.
300   for (Operation &op : *getBodyBlock()) {
301     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
302     if (!iface) {
303       // TODO: fill all possible effects; or require ops to actually implement
304       // the memory effect interface always
305       assert(false);
306     }
307 
308     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
309     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
310     for (const auto &effect : nestedEffects)
311       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
312   }
313 }
314 
315 OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
316   assert(index == 0 && "unexpected region index");
317   if (getOperation()->getNumOperands() == 1)
318     return getOperation()->getOperands();
319   return OperandRange(getOperation()->operand_end(),
320                       getOperation()->operand_end());
321 }
322 
323 void transform::SequenceOp::getSuccessorRegions(
324     Optional<unsigned> index, ArrayRef<Attribute> operands,
325     SmallVectorImpl<RegionSuccessor> &regions) {
326   if (!index.hasValue()) {
327     Region *bodyRegion = &getBody();
328     regions.emplace_back(bodyRegion, !operands.empty()
329                                          ? bodyRegion->getArguments()
330                                          : Block::BlockArgListType());
331     return;
332   }
333 
334   assert(*index == 0 && "unexpected region index");
335   regions.emplace_back(getOperation()->getResults());
336 }
337 
338 void transform::SequenceOp::getRegionInvocationBounds(
339     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
340   (void)operands;
341   bounds.emplace_back(1, 1);
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // WithPDLPatternsOp
346 //===----------------------------------------------------------------------===//
347 
348 LogicalResult
349 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
350                                     transform::TransformState &state) {
351   OwningOpRef<ModuleOp> pdlModuleOp =
352       ModuleOp::create(getOperation()->getLoc());
353   TransformOpInterface transformOp = nullptr;
354   for (Operation &nested : getBody().front()) {
355     if (!isa<pdl::PatternOp>(nested)) {
356       transformOp = cast<TransformOpInterface>(nested);
357       break;
358     }
359   }
360 
361   state.addExtension<PatternApplicatorExtension>(getOperation());
362   auto guard = llvm::make_scope_exit(
363       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
364 
365   auto scope = state.make_region_scope(getBody());
366   if (failed(mapBlockArguments(state)))
367     return failure();
368   return state.applyTransform(transformOp);
369 }
370 
371 LogicalResult transform::WithPDLPatternsOp::verify() {
372   Block *body = getBodyBlock();
373   Operation *topLevelOp = nullptr;
374   for (Operation &op : body->getOperations()) {
375     if (isa<pdl::PatternOp>(op))
376       continue;
377 
378     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
379       if (topLevelOp) {
380         InFlightDiagnostic diag =
381             emitOpError() << "expects only one non-pattern op in its body";
382         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
383         diag.attachNote(op.getLoc()) << "second non-pattern op";
384         return diag;
385       }
386       topLevelOp = &op;
387       continue;
388     }
389 
390     InFlightDiagnostic diag =
391         emitOpError()
392         << "expects only pattern and top-level transform ops in its body";
393     diag.attachNote(op.getLoc()) << "offending op";
394     return diag;
395   }
396 
397   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
398     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
399     diag.attachNote(parent.getLoc()) << "parent operation";
400     return diag;
401   }
402 
403   return success();
404 }
405