//===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
26 #define GET_OP_CLASSES
27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // PatternApplicatorExtension
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// A simple pattern rewriter that can be constructed from a context. This is
35 /// necessary to apply patterns to a specific op locally.
36 class TrivialPatternRewriter : public PatternRewriter {
37 public:
38   explicit TrivialPatternRewriter(MLIRContext *context)
39       : PatternRewriter(context) {}
40 };
41 
42 /// A TransformState extension that keeps track of compiled PDL pattern sets.
43 /// This is intended to be used along the WithPDLPatterns op. The extension
44 /// can be constructed given an operation that has a SymbolTable trait and
45 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
46 /// by one when requested; this behavior is subject to change.
47 class PatternApplicatorExtension : public transform::TransformState::Extension {
48 public:
49   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
50 
51   /// Creates the extension for patterns contained in `patternContainer`.
52   explicit PatternApplicatorExtension(transform::TransformState &state,
53                                       Operation *patternContainer)
54       : Extension(state), patterns(patternContainer) {}
55 
56   /// Appends to `results` the operations contained in `root` that matched the
57   /// PDL pattern with the given name. Note that `root` may or may not be the
58   /// operation that contains PDL patterns. Reports an error if the pattern
59   /// cannot be found. Note that when no operations are matched, this still
60   /// succeeds as long as the pattern exists.
61   LogicalResult findAllMatches(StringRef patternName, Operation *root,
62                                SmallVectorImpl<Operation *> &results);
63 
64 private:
65   /// Map from the pattern name to a singleton set of rewrite patterns that only
66   /// contains the pattern with this name. Populated when the pattern is first
67   /// requested.
68   // TODO: reconsider the efficiency of this storage when more usage data is
69   // available. Storing individual patterns in a set and triggering compilation
70   // for each of them has overhead. So does compiling a large set of patterns
71   // only to apply a handlful of them.
72   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
73 
74   /// A symbol table operation containing the relevant PDL patterns.
75   SymbolTable patterns;
76 };
77 
78 LogicalResult PatternApplicatorExtension::findAllMatches(
79     StringRef patternName, Operation *root,
80     SmallVectorImpl<Operation *> &results) {
81   auto it = compiledPatterns.find(patternName);
82   if (it == compiledPatterns.end()) {
83     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
84     if (!patternOp)
85       return failure();
86 
87     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
88     patternOp->moveBefore(pdlModuleOp->getBody(),
89                           pdlModuleOp->getBody()->end());
90     PDLPatternModule patternModule(std::move(pdlModuleOp));
91 
92     // Merge in the hooks owned by the dialect. Make a copy as they may be
93     // also used by the following operations.
94     auto *dialect =
95         root->getContext()->getLoadedDialect<transform::TransformDialect>();
96     for (const auto &pair : dialect->getPDLConstraintHooks())
97       patternModule.registerConstraintFunction(pair.first(), pair.second);
98 
99     // Register a noop rewriter because PDL requires patterns to end with some
100     // rewrite call.
101     patternModule.registerRewriteFunction(
102         "transform.dialect", [](PatternRewriter &, Operation *) {});
103 
104     it = compiledPatterns
105              .try_emplace(patternOp.getName(), std::move(patternModule))
106              .first;
107   }
108 
109   PatternApplicator applicator(it->second);
110   TrivialPatternRewriter rewriter(root->getContext());
111   applicator.applyDefaultCostModel();
112   root->walk([&](Operation *op) {
113     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
114       results.push_back(op);
115   });
116 
117   return success();
118 }
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // AlternativesOp
123 //===----------------------------------------------------------------------===//
124 
125 OperandRange
126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
127   if (index && getOperation()->getNumOperands() == 1)
128     return getOperation()->getOperands();
129   return OperandRange(getOperation()->operand_end(),
130                       getOperation()->operand_end());
131 }
132 
133 void transform::AlternativesOp::getSuccessorRegions(
134     Optional<unsigned> index, ArrayRef<Attribute> operands,
135     SmallVectorImpl<RegionSuccessor> &regions) {
136   for (Region &alternative :
137        llvm::drop_begin(getAlternatives(), index ? *index + 1 : 0)) {
138     regions.emplace_back(&alternative, !getOperands().empty()
139                                            ? alternative.getArguments()
140                                            : Block::BlockArgListType());
141   }
142   if (index)
143     regions.emplace_back(getOperation()->getResults());
144 }
145 
146 void transform::AlternativesOp::getRegionInvocationBounds(
147     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
148   (void)operands;
149   // The region corresponding to the first alternative is always executed, the
150   // remaining may or may not be executed.
151   bounds.reserve(getNumRegions());
152   bounds.emplace_back(1, 1);
153   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
154 }
155 
156 static void forwardTerminatorOperands(Block *block,
157                                       transform::TransformState &state,
158                                       transform::TransformResults &results) {
159   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
160                                     block->getParentOp()->getOpResults())) {
161     Value terminatorOperand = std::get<0>(pair);
162     OpResult result = std::get<1>(pair);
163     results.set(result, state.getPayloadOps(terminatorOperand));
164   }
165 }
166 
/// Tries the alternative regions in order on clones of the scope payload ops
/// and commits the first alternative in which no nested transform reports a
/// silenceable failure.
///
/// The scope payload ops are those associated with the `scope` operand or,
/// when it is absent, the top-level payload op. They must not contain this
/// transform op and must be isolated from above; violating either check emits
/// an error and aborts the application. For each alternative, the scope ops
/// are cloned and the clones are mapped to the region's block argument. If all
/// nested transforms succeed, every original scope op is replaced by its
/// transformed clone and the terminator operands are forwarded to this op's
/// results. If a nested transform reports a silenceable failure, the clones
/// are erased and the next alternative is tried. Definite failures are
/// returned immediately. If every alternative fails, a silenceable error is
/// returned.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  for (Operation *original : originals) {
    // On success the originals are replaced (and thereby erased), which would
    // also erase this transform op if it were nested inside one of them.
    if (original->isAncestor(getOperation())) {
      InFlightDiagnostic diag =
          emitError() << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    // Presumably required so that a clone does not refer to values defined
    // outside of the scope op it was cloned from — TODO confirm rationale.
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      InFlightDiagnostic diag =
          emitError()
          << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return DiagnosedSilenceableFailure(std::move(diag));
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // Erase the clones on every exit path unless this alternative succeeds and
    // the guard is released below. Declared after `scope`, so it runs first.
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      // A silenceable failure abandons this alternative; its diagnostic is
      // only printed in debug output and otherwise dropped.
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }

      // Only a definite failure can remain at this point; propagate it.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      IRRewriter rewriter(getContext());
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        // Splice the clone right before the original, then redirect the
        // original's uses to the clone and erase the original.
        // NOTE(review): assumes every scope op has a parent block; a top-level
        // payload op without one would be dereferenced as null here — confirm.
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
242 
243 LogicalResult transform::AlternativesOp::verify() {
244   for (Region &alternative : getAlternatives()) {
245     Block &block = alternative.front();
246     if (block.getNumArguments() != 1 ||
247         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
248       return emitOpError()
249              << "expects region blocks to have one operand of type "
250              << pdl::OperationType::get(getContext());
251     }
252 
253     Operation *terminator = block.getTerminator();
254     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
255       InFlightDiagnostic diag = emitOpError()
256                                 << "expects terminator operands to have the "
257                                    "same type as results of the operation";
258       diag.attachNote(terminator->getLoc()) << "terminator";
259       return diag;
260     }
261   }
262 
263   return success();
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // GetClosestIsolatedParentOp
268 //===----------------------------------------------------------------------===//
269 
270 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
271     transform::TransformResults &results, transform::TransformState &state) {
272   SetVector<Operation *> parents;
273   for (Operation *target : state.getPayloadOps(getTarget())) {
274     Operation *parent =
275         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
276     if (!parent) {
277       DiagnosedSilenceableFailure diag =
278           emitSilenceableError()
279           << "could not find an isolated-from-above parent op";
280       diag.attachNote(target->getLoc()) << "target op";
281       return diag;
282     }
283     parents.insert(parent);
284   }
285   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
286   return DiagnosedSilenceableFailure::success();
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // PDLMatchOp
291 //===----------------------------------------------------------------------===//
292 
293 DiagnosedSilenceableFailure
294 transform::PDLMatchOp::apply(transform::TransformResults &results,
295                              transform::TransformState &state) {
296   auto *extension = state.getExtension<PatternApplicatorExtension>();
297   assert(extension &&
298          "expected PatternApplicatorExtension to be attached by the parent op");
299   SmallVector<Operation *> targets;
300   for (Operation *root : state.getPayloadOps(getRoot())) {
301     if (failed(extension->findAllMatches(
302             getPatternName().getLeafReference().getValue(), root, targets))) {
303       emitOpError() << "could not find pattern '" << getPatternName() << "'";
304       return DiagnosedSilenceableFailure::definiteFailure();
305     }
306   }
307   results.set(getResult().cast<OpResult>(), targets);
308   return DiagnosedSilenceableFailure::success();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // SequenceOp
313 //===----------------------------------------------------------------------===//
314 
315 DiagnosedSilenceableFailure
316 transform::SequenceOp::apply(transform::TransformResults &results,
317                              transform::TransformState &state) {
318   // Map the entry block argument to the list of operations.
319   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
320   if (failed(mapBlockArguments(state)))
321     return DiagnosedSilenceableFailure::definiteFailure();
322 
323   // Apply the sequenced ops one by one.
324   for (Operation &transform : getBodyBlock()->without_terminator()) {
325     DiagnosedSilenceableFailure result =
326         state.applyTransform(cast<TransformOpInterface>(transform));
327     if (!result.succeeded())
328       return result;
329   }
330 
331   // Forward the operation mapping for values yielded from the sequence to the
332   // values produced by the sequence op.
333   forwardTerminatorOperands(getBodyBlock(), state, results);
334   return DiagnosedSilenceableFailure::success();
335 }
336 
337 /// Returns `true` if the given op operand may be consuming the handle value in
338 /// the Transform IR. That is, if it may have a Free effect on it.
339 static bool isValueUsePotentialConsumer(OpOperand &use) {
340   // Conservatively assume the effect being present in absence of the interface.
341   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
342   if (!memEffectInterface)
343     return true;
344 
345   SmallVector<MemoryEffects::EffectInstance, 2> effects;
346   memEffectInterface.getEffectsOnValue(use.get(), effects);
347   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
348     return isa<transform::TransformMappingResource>(effect.getResource()) &&
349            isa<MemoryEffects::Free>(effect.getEffect());
350   });
351 }
352 
353 LogicalResult
354 checkDoubleConsume(Value value,
355                    function_ref<InFlightDiagnostic()> reportError) {
356   OpOperand *potentialConsumer = nullptr;
357   for (OpOperand &use : value.getUses()) {
358     if (!isValueUsePotentialConsumer(use))
359       continue;
360 
361     if (!potentialConsumer) {
362       potentialConsumer = &use;
363       continue;
364     }
365 
366     InFlightDiagnostic diag = reportError()
367                               << " has more than one potential consumer";
368     diag.attachNote(potentialConsumer->getOwner()->getLoc())
369         << "used here as operand #" << potentialConsumer->getOperandNumber();
370     diag.attachNote(use.getOwner()->getLoc())
371         << "used here as operand #" << use.getOperandNumber();
372     return diag;
373   }
374 
375   return success();
376 }
377 
378 LogicalResult transform::SequenceOp::verify() {
379   // Check if the block argument has more than one consuming use.
380   for (BlockArgument argument : getBodyBlock()->getArguments()) {
381     auto report = [&]() {
382       return (emitOpError() << "block argument #" << argument.getArgNumber());
383     };
384     if (failed(checkDoubleConsume(argument, report)))
385       return failure();
386   }
387 
388   // Check properties of the nested operations they cannot check themselves.
389   for (Operation &child : *getBodyBlock()) {
390     if (!isa<TransformOpInterface>(child) &&
391         &child != &getBodyBlock()->back()) {
392       InFlightDiagnostic diag =
393           emitOpError()
394           << "expected children ops to implement TransformOpInterface";
395       diag.attachNote(child.getLoc()) << "op without interface";
396       return diag;
397     }
398 
399     for (OpResult result : child.getResults()) {
400       auto report = [&]() {
401         return (child.emitError() << "result #" << result.getResultNumber());
402       };
403       if (failed(checkDoubleConsume(result, report)))
404         return failure();
405     }
406   }
407 
408   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
409       getOperation()->getResultTypes()) {
410     InFlightDiagnostic diag = emitOpError()
411                               << "expects the types of the terminator operands "
412                                  "to match the types of the result";
413     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
414     return diag;
415   }
416   return success();
417 }
418 
419 void transform::SequenceOp::getEffects(
420     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
421   auto *mappingResource = TransformMappingResource::get();
422   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
423 
424   for (Value result : getResults()) {
425     effects.emplace_back(MemoryEffects::Allocate::get(), result,
426                          mappingResource);
427     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
428   }
429 
430   if (!getRoot()) {
431     for (Operation &op : *getBodyBlock()) {
432       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
433       if (!iface) {
434         // TODO: fill all possible effects; or require ops to actually implement
435         // the memory effect interface always
436         assert(false);
437       }
438 
439       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
440       iface.getEffects(effects);
441     }
442     return;
443   }
444 
445   // Carry over all effects on the argument of the entry block as those on the
446   // operand, this is the same value just remapped.
447   for (Operation &op : *getBodyBlock()) {
448     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
449     if (!iface) {
450       // TODO: fill all possible effects; or require ops to actually implement
451       // the memory effect interface always
452       assert(false);
453     }
454 
455     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
456     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
457     for (const auto &effect : nestedEffects)
458       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
459   }
460 }
461 
462 OperandRange
463 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
464   assert(index && *index == 0 && "unexpected region index");
465   if (getOperation()->getNumOperands() == 1)
466     return getOperation()->getOperands();
467   return OperandRange(getOperation()->operand_end(),
468                       getOperation()->operand_end());
469 }
470 
471 void transform::SequenceOp::getSuccessorRegions(
472     Optional<unsigned> index, ArrayRef<Attribute> operands,
473     SmallVectorImpl<RegionSuccessor> &regions) {
474   if (!index) {
475     Region *bodyRegion = &getBody();
476     regions.emplace_back(bodyRegion, !operands.empty()
477                                          ? bodyRegion->getArguments()
478                                          : Block::BlockArgListType());
479     return;
480   }
481 
482   assert(*index == 0 && "unexpected region index");
483   regions.emplace_back(getOperation()->getResults());
484 }
485 
486 void transform::SequenceOp::getRegionInvocationBounds(
487     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
488   (void)operands;
489   bounds.emplace_back(1, 1);
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // WithPDLPatternsOp
494 //===----------------------------------------------------------------------===//
495 
496 DiagnosedSilenceableFailure
497 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
498                                     transform::TransformState &state) {
499   OwningOpRef<ModuleOp> pdlModuleOp =
500       ModuleOp::create(getOperation()->getLoc());
501   TransformOpInterface transformOp = nullptr;
502   for (Operation &nested : getBody().front()) {
503     if (!isa<pdl::PatternOp>(nested)) {
504       transformOp = cast<TransformOpInterface>(nested);
505       break;
506     }
507   }
508 
509   state.addExtension<PatternApplicatorExtension>(getOperation());
510   auto guard = llvm::make_scope_exit(
511       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
512 
513   auto scope = state.make_region_scope(getBody());
514   if (failed(mapBlockArguments(state)))
515     return DiagnosedSilenceableFailure::definiteFailure();
516   return state.applyTransform(transformOp);
517 }
518 
519 LogicalResult transform::WithPDLPatternsOp::verify() {
520   Block *body = getBodyBlock();
521   Operation *topLevelOp = nullptr;
522   for (Operation &op : body->getOperations()) {
523     if (isa<pdl::PatternOp>(op))
524       continue;
525 
526     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
527       if (topLevelOp) {
528         InFlightDiagnostic diag =
529             emitOpError() << "expects only one non-pattern op in its body";
530         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
531         diag.attachNote(op.getLoc()) << "second non-pattern op";
532         return diag;
533       }
534       topLevelOp = &op;
535       continue;
536     }
537 
538     InFlightDiagnostic diag =
539         emitOpError()
540         << "expects only pattern and top-level transform ops in its body";
541     diag.attachNote(op.getLoc()) << "offending op";
542     return diag;
543   }
544 
545   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
546     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
547     diag.attachNote(parent.getLoc()) << "parent operation";
548     return diag;
549   }
550 
551   return success();
552 }
553