1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
26 #define GET_OP_CLASSES
27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // PatternApplicatorExtension
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// A simple pattern rewriter that can be constructed from a context. This is
35 /// necessary to apply patterns to a specific op locally.
36 class TrivialPatternRewriter : public PatternRewriter {
37 public:
38   explicit TrivialPatternRewriter(MLIRContext *context)
39       : PatternRewriter(context) {}
40 };
41 
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along the WithPDLPatterns op. The extension
/// can be constructed given an operation that has a SymbolTable trait and
/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
/// by one when requested; this behavior is subject to change.
class PatternApplicatorExtension : public transform::TransformState::Extension {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)

  /// Creates the extension for patterns contained in `patternContainer`. A
  /// symbol table is built over `patternContainer` and later used to look up
  /// patterns by name.
  explicit PatternApplicatorExtension(transform::TransformState &state,
                                      Operation *patternContainer)
      : Extension(state), patterns(patternContainer) {}

  /// Appends to `results` the operations contained in `root` that matched the
  /// PDL pattern with the given name. Note that `root` may or may not be the
  /// operation that contains PDL patterns. Returns failure, without emitting a
  /// diagnostic, if the pattern cannot be found; the caller is responsible for
  /// reporting. Note that when no operations are matched, this still succeeds
  /// as long as the pattern exists.
  LogicalResult findAllMatches(StringRef patternName, Operation *root,
                               SmallVectorImpl<Operation *> &results);

private:
  /// Map from the pattern name to a singleton set of rewrite patterns that only
  /// contains the pattern with this name. Populated when the pattern is first
  /// requested.
  // TODO: reconsider the efficiency of this storage when more usage data is
  // available. Storing individual patterns in a set and triggering compilation
  // for each of them has overhead. So does compiling a large set of patterns
  // only to apply a handful of them.
  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;

  /// A symbol table operation containing the relevant PDL patterns.
  SymbolTable patterns;
};
77 
78 LogicalResult PatternApplicatorExtension::findAllMatches(
79     StringRef patternName, Operation *root,
80     SmallVectorImpl<Operation *> &results) {
81   auto it = compiledPatterns.find(patternName);
82   if (it == compiledPatterns.end()) {
83     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
84     if (!patternOp)
85       return failure();
86 
87     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
88     patternOp->moveBefore(pdlModuleOp->getBody(),
89                           pdlModuleOp->getBody()->end());
90     PDLPatternModule patternModule(std::move(pdlModuleOp));
91 
92     // Merge in the hooks owned by the dialect. Make a copy as they may be
93     // also used by the following operations.
94     auto *dialect =
95         root->getContext()->getLoadedDialect<transform::TransformDialect>();
96     for (const auto &pair : dialect->getPDLConstraintHooks())
97       patternModule.registerConstraintFunction(pair.first(), pair.second);
98 
99     // Register a noop rewriter because PDL requires patterns to end with some
100     // rewrite call.
101     patternModule.registerRewriteFunction(
102         "transform.dialect", [](PatternRewriter &, Operation *) {});
103 
104     it = compiledPatterns
105              .try_emplace(patternOp.getName(), std::move(patternModule))
106              .first;
107   }
108 
109   PatternApplicator applicator(it->second);
110   TrivialPatternRewriter rewriter(root->getContext());
111   applicator.applyDefaultCostModel();
112   root->walk([&](Operation *op) {
113     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
114       results.push_back(op);
115   });
116 
117   return success();
118 }
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // AlternativesOp
123 //===----------------------------------------------------------------------===//
124 
125 OperandRange
126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
127   if (index && getOperation()->getNumOperands() == 1)
128     return getOperation()->getOperands();
129   return OperandRange(getOperation()->operand_end(),
130                       getOperation()->operand_end());
131 }
132 
133 void transform::AlternativesOp::getSuccessorRegions(
134     Optional<unsigned> index, ArrayRef<Attribute> operands,
135     SmallVectorImpl<RegionSuccessor> &regions) {
136   for (Region &alternative :
137        llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) {
138     regions.emplace_back(&alternative, !getOperands().empty()
139                                            ? alternative.getArguments()
140                                            : Block::BlockArgListType());
141   }
142   if (index.hasValue())
143     regions.emplace_back(getOperation()->getResults());
144 }
145 
146 void transform::AlternativesOp::getRegionInvocationBounds(
147     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
148   (void)operands;
149   // The region corresponding to the first alternative is always executed, the
150   // remaining may or may not be executed.
151   bounds.reserve(getNumRegions());
152   bounds.emplace_back(1, 1);
153   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
154 }
155 
156 static void forwardTerminatorOperands(Block *block,
157                                       transform::TransformState &state,
158                                       transform::TransformResults &results) {
159   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
160                                     block->getParentOp()->getOpResults())) {
161     Value terminatorOperand = std::get<0>(pair);
162     OpResult result = std::get<1>(pair);
163     results.set(result, state.getPayloadOps(terminatorOperand));
164   }
165 }
166 
167 DiagnosedSilenceableFailure
168 transform::AlternativesOp::apply(transform::TransformResults &results,
169                                  transform::TransformState &state) {
170   SmallVector<Operation *> originals;
171   if (Value scopeHandle = getScope())
172     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
173   else
174     originals.push_back(state.getTopLevel());
175 
176   for (Operation *original : originals) {
177     if (original->isAncestor(getOperation())) {
178       InFlightDiagnostic diag =
179           emitError() << "scope must not contain the transforms being applied";
180       diag.attachNote(original->getLoc()) << "scope";
181       return DiagnosedSilenceableFailure::definiteFailure();
182     }
183     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
184       InFlightDiagnostic diag =
185           emitError()
186           << "only isolated-from-above ops can be alternative scopes";
187       diag.attachNote(original->getLoc()) << "scope";
188       return DiagnosedSilenceableFailure(std::move(diag));
189     }
190   }
191 
192   for (Region &reg : getAlternatives()) {
193     // Clone the scope operations and make the transforms in this alternative
194     // region apply to them by virtue of mapping the block argument (the only
195     // visible handle) to the cloned scope operations. This effectively prevents
196     // the transformation from accessing any IR outside the scope.
197     auto scope = state.make_region_scope(reg);
198     auto clones = llvm::to_vector(
199         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
200     auto deleteClones = llvm::make_scope_exit([&] {
201       for (Operation *clone : clones)
202         clone->erase();
203     });
204     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
205       return DiagnosedSilenceableFailure::definiteFailure();
206 
207     bool failed = false;
208     for (Operation &transform : reg.front().without_terminator()) {
209       DiagnosedSilenceableFailure result =
210           state.applyTransform(cast<TransformOpInterface>(transform));
211       if (result.isSilenceableFailure()) {
212         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
213                           << "\n");
214         failed = true;
215         break;
216       }
217 
218       if (::mlir::failed(result.silence()))
219         return DiagnosedSilenceableFailure::definiteFailure();
220     }
221 
222     // If all operations in the given alternative succeeded, no need to consider
223     // the rest. Replace the original scoping operation with the clone on which
224     // the transformations were performed.
225     if (!failed) {
226       // We will be using the clones, so cancel their scheduled deletion.
227       deleteClones.release();
228       IRRewriter rewriter(getContext());
229       for (const auto &kvp : llvm::zip(originals, clones)) {
230         Operation *original = std::get<0>(kvp);
231         Operation *clone = std::get<1>(kvp);
232         original->getBlock()->getOperations().insert(original->getIterator(),
233                                                      clone);
234         rewriter.replaceOp(original, clone->getResults());
235       }
236       forwardTerminatorOperands(&reg.front(), state, results);
237       return DiagnosedSilenceableFailure::success();
238     }
239   }
240   return emitSilenceableError() << "all alternatives failed";
241 }
242 
243 LogicalResult transform::AlternativesOp::verify() {
244   for (Region &alternative : getAlternatives()) {
245     Block &block = alternative.front();
246     if (block.getNumArguments() != 1 ||
247         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
248       return emitOpError()
249              << "expects region blocks to have one operand of type "
250              << pdl::OperationType::get(getContext());
251     }
252 
253     Operation *terminator = block.getTerminator();
254     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
255       InFlightDiagnostic diag = emitOpError()
256                                 << "expects terminator operands to have the "
257                                    "same type as results of the operation";
258       diag.attachNote(terminator->getLoc()) << "terminator";
259       return diag;
260     }
261   }
262 
263   return success();
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // GetClosestIsolatedParentOp
268 //===----------------------------------------------------------------------===//
269 
270 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
271     transform::TransformResults &results, transform::TransformState &state) {
272   SetVector<Operation *> parents;
273   for (Operation *target : state.getPayloadOps(getTarget())) {
274     Operation *parent =
275         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
276     if (!parent) {
277       DiagnosedSilenceableFailure diag =
278           emitSilenceableError()
279           << "could not find an isolated-from-above parent op";
280       diag.attachNote(target->getLoc()) << "target op";
281       return diag;
282     }
283     parents.insert(parent);
284   }
285   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
286   return DiagnosedSilenceableFailure::success();
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // MergeHandlesOp
291 //===----------------------------------------------------------------------===//
292 
293 DiagnosedSilenceableFailure
294 transform::MergeHandlesOp::apply(transform::TransformResults &results,
295                                  transform::TransformState &state) {
296   SmallVector<Operation *> operations;
297   for (Value operand : getHandles())
298     llvm::append_range(operations, state.getPayloadOps(operand));
299   if (!getDeduplicate()) {
300     results.set(getResult().cast<OpResult>(), operations);
301     return DiagnosedSilenceableFailure::success();
302   }
303 
304   SetVector<Operation *> uniqued(operations.begin(), operations.end());
305   results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
306   return DiagnosedSilenceableFailure::success();
307 }
308 
309 void transform::MergeHandlesOp::getEffects(
310     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
311   for (Value operand : getHandles()) {
312     effects.emplace_back(MemoryEffects::Read::get(), operand,
313                          transform::TransformMappingResource::get());
314     effects.emplace_back(MemoryEffects::Free::get(), operand,
315                          transform::TransformMappingResource::get());
316   }
317   effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
318                        transform::TransformMappingResource::get());
319   effects.emplace_back(MemoryEffects::Write::get(), getResult(),
320                        transform::TransformMappingResource::get());
321 
322   // There are no effects on the Payload IR as this is only a handle
323   // manipulation.
324 }
325 
326 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
327   if (getDeduplicate() || getHandles().size() != 1)
328     return {};
329 
330   // If deduplication is not required and there is only one operand, it can be
331   // used directly instead of merging.
332   return getHandles().front();
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // PDLMatchOp
337 //===----------------------------------------------------------------------===//
338 
339 DiagnosedSilenceableFailure
340 transform::PDLMatchOp::apply(transform::TransformResults &results,
341                              transform::TransformState &state) {
342   auto *extension = state.getExtension<PatternApplicatorExtension>();
343   assert(extension &&
344          "expected PatternApplicatorExtension to be attached by the parent op");
345   SmallVector<Operation *> targets;
346   for (Operation *root : state.getPayloadOps(getRoot())) {
347     if (failed(extension->findAllMatches(
348             getPatternName().getLeafReference().getValue(), root, targets))) {
349       emitOpError() << "could not find pattern '" << getPatternName() << "'";
350       return DiagnosedSilenceableFailure::definiteFailure();
351     }
352   }
353   results.set(getResult().cast<OpResult>(), targets);
354   return DiagnosedSilenceableFailure::success();
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // SequenceOp
359 //===----------------------------------------------------------------------===//
360 
361 DiagnosedSilenceableFailure
362 transform::SequenceOp::apply(transform::TransformResults &results,
363                              transform::TransformState &state) {
364   // Map the entry block argument to the list of operations.
365   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
366   if (failed(mapBlockArguments(state)))
367     return DiagnosedSilenceableFailure::definiteFailure();
368 
369   // Apply the sequenced ops one by one.
370   for (Operation &transform : getBodyBlock()->without_terminator()) {
371     DiagnosedSilenceableFailure result =
372         state.applyTransform(cast<TransformOpInterface>(transform));
373     if (!result.succeeded())
374       return result;
375   }
376 
377   // Forward the operation mapping for values yielded from the sequence to the
378   // values produced by the sequence op.
379   forwardTerminatorOperands(getBodyBlock(), state, results);
380   return DiagnosedSilenceableFailure::success();
381 }
382 
383 /// Returns `true` if the given op operand may be consuming the handle value in
384 /// the Transform IR. That is, if it may have a Free effect on it.
385 static bool isValueUsePotentialConsumer(OpOperand &use) {
386   // Conservatively assume the effect being present in absence of the interface.
387   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
388   if (!memEffectInterface)
389     return true;
390 
391   SmallVector<MemoryEffects::EffectInstance, 2> effects;
392   memEffectInterface.getEffectsOnValue(use.get(), effects);
393   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
394     return isa<transform::TransformMappingResource>(effect.getResource()) &&
395            isa<MemoryEffects::Free>(effect.getEffect());
396   });
397 }
398 
399 LogicalResult
400 checkDoubleConsume(Value value,
401                    function_ref<InFlightDiagnostic()> reportError) {
402   OpOperand *potentialConsumer = nullptr;
403   for (OpOperand &use : value.getUses()) {
404     if (!isValueUsePotentialConsumer(use))
405       continue;
406 
407     if (!potentialConsumer) {
408       potentialConsumer = &use;
409       continue;
410     }
411 
412     InFlightDiagnostic diag = reportError()
413                               << " has more than one potential consumer";
414     diag.attachNote(potentialConsumer->getOwner()->getLoc())
415         << "used here as operand #" << potentialConsumer->getOperandNumber();
416     diag.attachNote(use.getOwner()->getLoc())
417         << "used here as operand #" << use.getOperandNumber();
418     return diag;
419   }
420 
421   return success();
422 }
423 
424 LogicalResult transform::SequenceOp::verify() {
425   // Check if the block argument has more than one consuming use.
426   for (BlockArgument argument : getBodyBlock()->getArguments()) {
427     auto report = [&]() {
428       return (emitOpError() << "block argument #" << argument.getArgNumber());
429     };
430     if (failed(checkDoubleConsume(argument, report)))
431       return failure();
432   }
433 
434   // Check properties of the nested operations they cannot check themselves.
435   for (Operation &child : *getBodyBlock()) {
436     if (!isa<TransformOpInterface>(child) &&
437         &child != &getBodyBlock()->back()) {
438       InFlightDiagnostic diag =
439           emitOpError()
440           << "expected children ops to implement TransformOpInterface";
441       diag.attachNote(child.getLoc()) << "op without interface";
442       return diag;
443     }
444 
445     for (OpResult result : child.getResults()) {
446       auto report = [&]() {
447         return (child.emitError() << "result #" << result.getResultNumber());
448       };
449       if (failed(checkDoubleConsume(result, report)))
450         return failure();
451     }
452   }
453 
454   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
455       getOperation()->getResultTypes()) {
456     InFlightDiagnostic diag = emitOpError()
457                               << "expects the types of the terminator operands "
458                                  "to match the types of the result";
459     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
460     return diag;
461   }
462   return success();
463 }
464 
465 void transform::SequenceOp::getEffects(
466     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
467   auto *mappingResource = TransformMappingResource::get();
468   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
469 
470   for (Value result : getResults()) {
471     effects.emplace_back(MemoryEffects::Allocate::get(), result,
472                          mappingResource);
473     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
474   }
475 
476   if (!getRoot()) {
477     for (Operation &op : *getBodyBlock()) {
478       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
479       if (!iface) {
480         // TODO: fill all possible effects; or require ops to actually implement
481         // the memory effect interface always
482         assert(false);
483       }
484 
485       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
486       iface.getEffects(effects);
487     }
488     return;
489   }
490 
491   // Carry over all effects on the argument of the entry block as those on the
492   // operand, this is the same value just remapped.
493   for (Operation &op : *getBodyBlock()) {
494     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
495     if (!iface) {
496       // TODO: fill all possible effects; or require ops to actually implement
497       // the memory effect interface always
498       assert(false);
499     }
500 
501     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
502     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
503     for (const auto &effect : nestedEffects)
504       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
505   }
506 }
507 
508 OperandRange
509 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
510   assert(index && *index == 0 && "unexpected region index");
511   if (getOperation()->getNumOperands() == 1)
512     return getOperation()->getOperands();
513   return OperandRange(getOperation()->operand_end(),
514                       getOperation()->operand_end());
515 }
516 
517 void transform::SequenceOp::getSuccessorRegions(
518     Optional<unsigned> index, ArrayRef<Attribute> operands,
519     SmallVectorImpl<RegionSuccessor> &regions) {
520   if (!index) {
521     Region *bodyRegion = &getBody();
522     regions.emplace_back(bodyRegion, !operands.empty()
523                                          ? bodyRegion->getArguments()
524                                          : Block::BlockArgListType());
525     return;
526   }
527 
528   assert(*index == 0 && "unexpected region index");
529   regions.emplace_back(getOperation()->getResults());
530 }
531 
532 void transform::SequenceOp::getRegionInvocationBounds(
533     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
534   (void)operands;
535   bounds.emplace_back(1, 1);
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // WithPDLPatternsOp
540 //===----------------------------------------------------------------------===//
541 
542 DiagnosedSilenceableFailure
543 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
544                                     transform::TransformState &state) {
545   OwningOpRef<ModuleOp> pdlModuleOp =
546       ModuleOp::create(getOperation()->getLoc());
547   TransformOpInterface transformOp = nullptr;
548   for (Operation &nested : getBody().front()) {
549     if (!isa<pdl::PatternOp>(nested)) {
550       transformOp = cast<TransformOpInterface>(nested);
551       break;
552     }
553   }
554 
555   state.addExtension<PatternApplicatorExtension>(getOperation());
556   auto guard = llvm::make_scope_exit(
557       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
558 
559   auto scope = state.make_region_scope(getBody());
560   if (failed(mapBlockArguments(state)))
561     return DiagnosedSilenceableFailure::definiteFailure();
562   return state.applyTransform(transformOp);
563 }
564 
565 LogicalResult transform::WithPDLPatternsOp::verify() {
566   Block *body = getBodyBlock();
567   Operation *topLevelOp = nullptr;
568   for (Operation &op : body->getOperations()) {
569     if (isa<pdl::PatternOp>(op))
570       continue;
571 
572     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
573       if (topLevelOp) {
574         InFlightDiagnostic diag =
575             emitOpError() << "expects only one non-pattern op in its body";
576         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
577         diag.attachNote(op.getLoc()) << "second non-pattern op";
578         return diag;
579       }
580       topLevelOp = &op;
581       continue;
582     }
583 
584     InFlightDiagnostic diag =
585         emitOpError()
586         << "expects only pattern and top-level transform ops in its body";
587     diag.attachNote(op.getLoc()) << "offending op";
588     return diag;
589   }
590 
591   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
592     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
593     diag.attachNote(parent.getLoc()) << "parent operation";
594     return diag;
595   }
596 
597   return success();
598 }
599