//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Interfaces/ControlFlowInterfaces.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Rewrite/PatternApplicator.h"
19 #include "llvm/ADT/ScopeExit.h"
20 
21 using namespace mlir;
22 
23 #define GET_OP_CLASSES
24 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // PatternApplicatorExtension
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 /// A simple pattern rewriter that can be constructed from a context. This is
32 /// necessary to apply patterns to a specific op locally.
33 class TrivialPatternRewriter : public PatternRewriter {
34 public:
35   explicit TrivialPatternRewriter(MLIRContext *context)
36       : PatternRewriter(context) {}
37 };
38 
39 /// A TransformState extension that keeps track of compiled PDL pattern sets.
40 /// This is intended to be used along the WithPDLPatterns op. The extension
41 /// can be constructed given an operation that has a SymbolTable trait and
42 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
43 /// by one when requested; this behavior is subject to change.
44 class PatternApplicatorExtension : public transform::TransformState::Extension {
45 public:
46   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
47 
48   /// Creates the extension for patterns contained in `patternContainer`.
49   explicit PatternApplicatorExtension(transform::TransformState &state,
50                                       Operation *patternContainer)
51       : Extension(state), patterns(patternContainer) {}
52 
53   /// Appends to `results` the operations contained in `root` that matched the
54   /// PDL pattern with the given name. Note that `root` may or may not be the
55   /// operation that contains PDL patterns. Reports an error if the pattern
56   /// cannot be found. Note that when no operations are matched, this still
57   /// succeeds as long as the pattern exists.
58   LogicalResult findAllMatches(StringRef patternName, Operation *root,
59                                SmallVectorImpl<Operation *> &results);
60 
61 private:
62   /// Map from the pattern name to a singleton set of rewrite patterns that only
63   /// contains the pattern with this name. Populated when the pattern is first
64   /// requested.
65   // TODO: reconsider the efficiency of this storage when more usage data is
66   // available. Storing individual patterns in a set and triggering compilation
67   // for each of them has overhead. So does compiling a large set of patterns
68   // only to apply a handlful of them.
69   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
70 
71   /// A symbol table operation containing the relevant PDL patterns.
72   SymbolTable patterns;
73 };
74 
75 LogicalResult PatternApplicatorExtension::findAllMatches(
76     StringRef patternName, Operation *root,
77     SmallVectorImpl<Operation *> &results) {
78   auto it = compiledPatterns.find(patternName);
79   if (it == compiledPatterns.end()) {
80     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
81     if (!patternOp)
82       return failure();
83 
84     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
85     patternOp->moveBefore(pdlModuleOp->getBody(),
86                           pdlModuleOp->getBody()->end());
87     PDLPatternModule patternModule(std::move(pdlModuleOp));
88 
89     // Merge in the hooks owned by the dialect. Make a copy as they may be
90     // also used by the following operations.
91     auto *dialect =
92         root->getContext()->getLoadedDialect<transform::TransformDialect>();
93     for (const auto &pair : dialect->getPDLConstraintHooks())
94       patternModule.registerConstraintFunction(pair.first(), pair.second);
95 
96     // Register a noop rewriter because PDL requires patterns to end with some
97     // rewrite call.
98     patternModule.registerRewriteFunction(
99         "transform.dialect", [](PatternRewriter &, Operation *) {});
100 
101     it = compiledPatterns
102              .try_emplace(patternOp.getName(), std::move(patternModule))
103              .first;
104   }
105 
106   PatternApplicator applicator(it->second);
107   TrivialPatternRewriter rewriter(root->getContext());
108   applicator.applyDefaultCostModel();
109   root->walk([&](Operation *op) {
110     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
111       results.push_back(op);
112   });
113 
114   return success();
115 }
116 } // namespace
117 
118 //===----------------------------------------------------------------------===//
119 // GetClosestIsolatedParentOp
120 //===----------------------------------------------------------------------===//
121 
122 LogicalResult transform::GetClosestIsolatedParentOp::apply(
123     transform::TransformResults &results, transform::TransformState &state) {
124   SetVector<Operation *> parents;
125   for (Operation *target : state.getPayloadOps(getTarget())) {
126     Operation *parent =
127         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
128     if (!parent) {
129       InFlightDiagnostic diag =
130           emitError() << "could not find an isolated-from-above parent op";
131       diag.attachNote(target->getLoc()) << "target op";
132       return diag;
133     }
134     parents.insert(parent);
135   }
136   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
137   return success();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // PDLMatchOp
142 //===----------------------------------------------------------------------===//
143 
144 LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
145                                            transform::TransformState &state) {
146   auto *extension = state.getExtension<PatternApplicatorExtension>();
147   assert(extension &&
148          "expected PatternApplicatorExtension to be attached by the parent op");
149   SmallVector<Operation *> targets;
150   for (Operation *root : state.getPayloadOps(getRoot())) {
151     if (failed(extension->findAllMatches(
152             getPatternName().getLeafReference().getValue(), root, targets))) {
153       return emitOpError() << "could not find pattern '" << getPatternName()
154                            << "'";
155     }
156   }
157   results.set(getResult().cast<OpResult>(), targets);
158   return success();
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // SequenceOp
163 //===----------------------------------------------------------------------===//
164 
165 LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
166                                            transform::TransformState &state) {
167   // Map the entry block argument to the list of operations.
168   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
169   if (failed(mapBlockArguments(state)))
170     return failure();
171 
172   // Apply the sequenced ops one by one.
173   for (Operation &transform : getBodyBlock()->without_terminator())
174     if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
175       return failure();
176 
177   // Forward the operation mapping for values yielded from the sequence to the
178   // values produced by the sequence op.
179   for (const auto &pair :
180        llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
181                  getOperation()->getOpResults())) {
182     Value terminatorOperand = std::get<0>(pair);
183     OpResult result = std::get<1>(pair);
184     results.set(result, state.getPayloadOps(terminatorOperand));
185   }
186 
187   return success();
188 }
189 
190 /// Returns `true` if the given op operand may be consuming the handle value in
191 /// the Transform IR. That is, if it may have a Free effect on it.
192 static bool isValueUsePotentialConsumer(OpOperand &use) {
193   // Conservatively assume the effect being present in absence of the interface.
194   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
195   if (!memEffectInterface)
196     return true;
197 
198   SmallVector<MemoryEffects::EffectInstance, 2> effects;
199   memEffectInterface.getEffectsOnValue(use.get(), effects);
200   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
201     return isa<transform::TransformMappingResource>(effect.getResource()) &&
202            isa<MemoryEffects::Free>(effect.getEffect());
203   });
204 }
205 
206 LogicalResult
207 checkDoubleConsume(Value value,
208                    function_ref<InFlightDiagnostic()> reportError) {
209   OpOperand *potentialConsumer = nullptr;
210   for (OpOperand &use : value.getUses()) {
211     if (!isValueUsePotentialConsumer(use))
212       continue;
213 
214     if (!potentialConsumer) {
215       potentialConsumer = &use;
216       continue;
217     }
218 
219     InFlightDiagnostic diag = reportError()
220                               << " has more than one potential consumer";
221     diag.attachNote(potentialConsumer->getOwner()->getLoc())
222         << "used here as operand #" << potentialConsumer->getOperandNumber();
223     diag.attachNote(use.getOwner()->getLoc())
224         << "used here as operand #" << use.getOperandNumber();
225     return diag;
226   }
227 
228   return success();
229 }
230 
231 LogicalResult transform::SequenceOp::verify() {
232   // Check if the block argument has more than one consuming use.
233   for (BlockArgument argument : getBodyBlock()->getArguments()) {
234     auto report = [&]() {
235       return (emitOpError() << "block argument #" << argument.getArgNumber());
236     };
237     if (failed(checkDoubleConsume(argument, report)))
238       return failure();
239   }
240 
241   // Check properties of the nested operations they cannot check themselves.
242   for (Operation &child : *getBodyBlock()) {
243     if (!isa<TransformOpInterface>(child) &&
244         &child != &getBodyBlock()->back()) {
245       InFlightDiagnostic diag =
246           emitOpError()
247           << "expected children ops to implement TransformOpInterface";
248       diag.attachNote(child.getLoc()) << "op without interface";
249       return diag;
250     }
251 
252     for (OpResult result : child.getResults()) {
253       auto report = [&]() {
254         return (child.emitError() << "result #" << result.getResultNumber());
255       };
256       if (failed(checkDoubleConsume(result, report)))
257         return failure();
258     }
259   }
260 
261   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
262       getOperation()->getResultTypes()) {
263     InFlightDiagnostic diag = emitOpError()
264                               << "expects the types of the terminator operands "
265                                  "to match the types of the result";
266     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
267     return diag;
268   }
269   return success();
270 }
271 
272 void transform::SequenceOp::getEffects(
273     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
274   auto *mappingResource = TransformMappingResource::get();
275   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
276 
277   for (Value result : getResults()) {
278     effects.emplace_back(MemoryEffects::Allocate::get(), result,
279                          mappingResource);
280     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
281   }
282 
283   if (!getRoot()) {
284     for (Operation &op : *getBodyBlock()) {
285       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
286       if (!iface) {
287         // TODO: fill all possible effects; or require ops to actually implement
288         // the memory effect interface always
289         assert(false);
290       }
291 
292       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
293       iface.getEffects(effects);
294     }
295     return;
296   }
297 
298   // Carry over all effects on the argument of the entry block as those on the
299   // operand, this is the same value just remapped.
300   for (Operation &op : *getBodyBlock()) {
301     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
302     if (!iface) {
303       // TODO: fill all possible effects; or require ops to actually implement
304       // the memory effect interface always
305       assert(false);
306     }
307 
308     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
309     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
310     for (const auto &effect : nestedEffects)
311       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
312   }
313 }
314 
315 OperandRange
316 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
317   assert(index && *index == 0 && "unexpected region index");
318   if (getOperation()->getNumOperands() == 1)
319     return getOperation()->getOperands();
320   return OperandRange(getOperation()->operand_end(),
321                       getOperation()->operand_end());
322 }
323 
324 void transform::SequenceOp::getSuccessorRegions(
325     Optional<unsigned> index, ArrayRef<Attribute> operands,
326     SmallVectorImpl<RegionSuccessor> &regions) {
327   if (!index.hasValue()) {
328     Region *bodyRegion = &getBody();
329     regions.emplace_back(bodyRegion, !operands.empty()
330                                          ? bodyRegion->getArguments()
331                                          : Block::BlockArgListType());
332     return;
333   }
334 
335   assert(*index == 0 && "unexpected region index");
336   regions.emplace_back(getOperation()->getResults());
337 }
338 
339 void transform::SequenceOp::getRegionInvocationBounds(
340     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
341   (void)operands;
342   bounds.emplace_back(1, 1);
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // WithPDLPatternsOp
347 //===----------------------------------------------------------------------===//
348 
349 LogicalResult
350 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
351                                     transform::TransformState &state) {
352   OwningOpRef<ModuleOp> pdlModuleOp =
353       ModuleOp::create(getOperation()->getLoc());
354   TransformOpInterface transformOp = nullptr;
355   for (Operation &nested : getBody().front()) {
356     if (!isa<pdl::PatternOp>(nested)) {
357       transformOp = cast<TransformOpInterface>(nested);
358       break;
359     }
360   }
361 
362   state.addExtension<PatternApplicatorExtension>(getOperation());
363   auto guard = llvm::make_scope_exit(
364       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
365 
366   auto scope = state.make_region_scope(getBody());
367   if (failed(mapBlockArguments(state)))
368     return failure();
369   return state.applyTransform(transformOp);
370 }
371 
372 LogicalResult transform::WithPDLPatternsOp::verify() {
373   Block *body = getBodyBlock();
374   Operation *topLevelOp = nullptr;
375   for (Operation &op : body->getOperations()) {
376     if (isa<pdl::PatternOp>(op))
377       continue;
378 
379     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
380       if (topLevelOp) {
381         InFlightDiagnostic diag =
382             emitOpError() << "expects only one non-pattern op in its body";
383         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
384         diag.attachNote(op.getLoc()) << "second non-pattern op";
385         return diag;
386       }
387       topLevelOp = &op;
388       continue;
389     }
390 
391     InFlightDiagnostic diag =
392         emitOpError()
393         << "expects only pattern and top-level transform ops in its body";
394     diag.attachNote(op.getLoc()) << "offending op";
395     return diag;
396   }
397 
398   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
399     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
400     diag.attachNote(parent.getLoc()) << "parent operation";
401     return diag;
402   }
403 
404   return success();
405 }
406