1 //===- TransformDialect.cpp - Transform dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformOps.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
12 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/ControlFlowInterfaces.h"
16 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
17 #include "mlir/Rewrite/PatternApplicator.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
26 #define GET_OP_CLASSES
27 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // PatternApplicatorExtension
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// A simple pattern rewriter that can be constructed from a context. This is
35 /// necessary to apply patterns to a specific op locally.
36 class TrivialPatternRewriter : public PatternRewriter {
37 public:
38   explicit TrivialPatternRewriter(MLIRContext *context)
39       : PatternRewriter(context) {}
40 };
41 
42 /// A TransformState extension that keeps track of compiled PDL pattern sets.
43 /// This is intended to be used along the WithPDLPatterns op. The extension
44 /// can be constructed given an operation that has a SymbolTable trait and
45 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
46 /// by one when requested; this behavior is subject to change.
47 class PatternApplicatorExtension : public transform::TransformState::Extension {
48 public:
49   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
50 
51   /// Creates the extension for patterns contained in `patternContainer`.
52   explicit PatternApplicatorExtension(transform::TransformState &state,
53                                       Operation *patternContainer)
54       : Extension(state), patterns(patternContainer) {}
55 
56   /// Appends to `results` the operations contained in `root` that matched the
57   /// PDL pattern with the given name. Note that `root` may or may not be the
58   /// operation that contains PDL patterns. Reports an error if the pattern
59   /// cannot be found. Note that when no operations are matched, this still
60   /// succeeds as long as the pattern exists.
61   LogicalResult findAllMatches(StringRef patternName, Operation *root,
62                                SmallVectorImpl<Operation *> &results);
63 
64 private:
65   /// Map from the pattern name to a singleton set of rewrite patterns that only
66   /// contains the pattern with this name. Populated when the pattern is first
67   /// requested.
68   // TODO: reconsider the efficiency of this storage when more usage data is
69   // available. Storing individual patterns in a set and triggering compilation
70   // for each of them has overhead. So does compiling a large set of patterns
71   // only to apply a handlful of them.
72   llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
73 
74   /// A symbol table operation containing the relevant PDL patterns.
75   SymbolTable patterns;
76 };
77 
78 LogicalResult PatternApplicatorExtension::findAllMatches(
79     StringRef patternName, Operation *root,
80     SmallVectorImpl<Operation *> &results) {
81   auto it = compiledPatterns.find(patternName);
82   if (it == compiledPatterns.end()) {
83     auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
84     if (!patternOp)
85       return failure();
86 
87     OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
88     patternOp->moveBefore(pdlModuleOp->getBody(),
89                           pdlModuleOp->getBody()->end());
90     PDLPatternModule patternModule(std::move(pdlModuleOp));
91 
92     // Merge in the hooks owned by the dialect. Make a copy as they may be
93     // also used by the following operations.
94     auto *dialect =
95         root->getContext()->getLoadedDialect<transform::TransformDialect>();
96     for (const auto &pair : dialect->getPDLConstraintHooks())
97       patternModule.registerConstraintFunction(pair.first(), pair.second);
98 
99     // Register a noop rewriter because PDL requires patterns to end with some
100     // rewrite call.
101     patternModule.registerRewriteFunction(
102         "transform.dialect", [](PatternRewriter &, Operation *) {});
103 
104     it = compiledPatterns
105              .try_emplace(patternOp.getName(), std::move(patternModule))
106              .first;
107   }
108 
109   PatternApplicator applicator(it->second);
110   TrivialPatternRewriter rewriter(root->getContext());
111   applicator.applyDefaultCostModel();
112   root->walk([&](Operation *op) {
113     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
114       results.push_back(op);
115   });
116 
117   return success();
118 }
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // AlternativesOp
123 //===----------------------------------------------------------------------===//
124 
125 OperandRange
126 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
127   if (index.hasValue() && getOperation()->getNumOperands() == 1)
128     return getOperation()->getOperands();
129   return OperandRange(getOperation()->operand_end(),
130                       getOperation()->operand_end());
131 }
132 
133 void transform::AlternativesOp::getSuccessorRegions(
134     Optional<unsigned> index, ArrayRef<Attribute> operands,
135     SmallVectorImpl<RegionSuccessor> &regions) {
136   for (Region &alternative :
137        llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) {
138     regions.emplace_back(&alternative, !getOperands().empty()
139                                            ? alternative.getArguments()
140                                            : Block::BlockArgListType());
141   }
142   if (index.hasValue())
143     regions.emplace_back(getOperation()->getResults());
144 }
145 
146 void transform::AlternativesOp::getRegionInvocationBounds(
147     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
148   (void)operands;
149   // The region corresponding to the first alternative is always executed, the
150   // remaining may or may not be executed.
151   bounds.reserve(getNumRegions());
152   bounds.emplace_back(1, 1);
153   bounds.resize(getNumRegions(), InvocationBounds(0, 1));
154 }
155 
156 static void forwardTerminatorOperands(Block *block,
157                                       transform::TransformState &state,
158                                       transform::TransformResults &results) {
159   for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
160                                     block->getParentOp()->getOpResults())) {
161     Value terminatorOperand = std::get<0>(pair);
162     OpResult result = std::get<1>(pair);
163     results.set(result, state.getPayloadOps(terminatorOperand));
164   }
165 }
166 
167 DiagnosedSilencableFailure
168 transform::AlternativesOp::apply(transform::TransformResults &results,
169                                  transform::TransformState &state) {
170   SmallVector<Operation *> originals;
171   if (Value scopeHandle = getScope())
172     llvm::append_range(originals, state.getPayloadOps(scopeHandle));
173   else
174     originals.push_back(state.getTopLevel());
175 
176   for (Operation *original : originals) {
177     if (original->isAncestor(getOperation())) {
178       InFlightDiagnostic diag =
179           emitError() << "scope must not contain the transforms being applied";
180       diag.attachNote(original->getLoc()) << "scope";
181       return DiagnosedSilencableFailure::definiteFailure();
182     }
183   }
184 
185   for (Region &reg : getAlternatives()) {
186     // Clone the scope operations and make the transforms in this alternative
187     // region apply to them by virtue of mapping the block argument (the only
188     // visible handle) to the cloned scope operations. This effectively prevents
189     // the transformation from accessing any IR outside the scope.
190     auto scope = state.make_region_scope(reg);
191     auto clones = llvm::to_vector(
192         llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
193     if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
194       return DiagnosedSilencableFailure::definiteFailure();
195     auto deleteClones = llvm::make_scope_exit([&] {
196       for (Operation *clone : clones)
197         clone->erase();
198     });
199 
200     bool failed = false;
201     for (Operation &transform : reg.front().without_terminator()) {
202       DiagnosedSilencableFailure result =
203           state.applyTransform(cast<TransformOpInterface>(transform));
204       if (result.isSilencableFailure()) {
205         LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
206                           << "\n");
207         failed = true;
208         break;
209       }
210 
211       if (::mlir::failed(result.silence()))
212         return DiagnosedSilencableFailure::definiteFailure();
213     }
214 
215     // If all operations in the given alternative succeeded, no need to consider
216     // the rest. Replace the original scoping operation with the clone on which
217     // the transformations were performed.
218     if (!failed) {
219       // We will be using the clones, so cancel their scheduled deletion.
220       deleteClones.release();
221       IRRewriter rewriter(getContext());
222       for (const auto &kvp : llvm::zip(originals, clones)) {
223         Operation *original = std::get<0>(kvp);
224         Operation *clone = std::get<1>(kvp);
225         original->getBlock()->getOperations().insert(original->getIterator(),
226                                                      clone);
227         rewriter.replaceOp(original, clone->getResults());
228       }
229       forwardTerminatorOperands(&reg.front(), state, results);
230       return DiagnosedSilencableFailure::success();
231     }
232   }
233   return emitSilencableError() << "all alternatives failed";
234 }
235 
236 LogicalResult transform::AlternativesOp::verify() {
237   for (Region &alternative : getAlternatives()) {
238     Block &block = alternative.front();
239     if (block.getNumArguments() != 1 ||
240         !block.getArgument(0).getType().isa<pdl::OperationType>()) {
241       return emitOpError()
242              << "expects region blocks to have one operand of type "
243              << pdl::OperationType::get(getContext());
244     }
245 
246     Operation *terminator = block.getTerminator();
247     if (terminator->getOperands().getTypes() != getResults().getTypes()) {
248       InFlightDiagnostic diag = emitOpError()
249                                 << "expects terminator operands to have the "
250                                    "same type as results of the operation";
251       diag.attachNote(terminator->getLoc()) << "terminator";
252       return diag;
253     }
254   }
255 
256   return success();
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // GetClosestIsolatedParentOp
261 //===----------------------------------------------------------------------===//
262 
263 DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply(
264     transform::TransformResults &results, transform::TransformState &state) {
265   SetVector<Operation *> parents;
266   for (Operation *target : state.getPayloadOps(getTarget())) {
267     Operation *parent =
268         target->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
269     if (!parent) {
270       DiagnosedSilencableFailure diag =
271           emitSilencableError()
272           << "could not find an isolated-from-above parent op";
273       diag.attachNote(target->getLoc()) << "target op";
274       return diag;
275     }
276     parents.insert(parent);
277   }
278   results.set(getResult().cast<OpResult>(), parents.getArrayRef());
279   return DiagnosedSilencableFailure::success();
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // PDLMatchOp
284 //===----------------------------------------------------------------------===//
285 
286 DiagnosedSilencableFailure
287 transform::PDLMatchOp::apply(transform::TransformResults &results,
288                              transform::TransformState &state) {
289   auto *extension = state.getExtension<PatternApplicatorExtension>();
290   assert(extension &&
291          "expected PatternApplicatorExtension to be attached by the parent op");
292   SmallVector<Operation *> targets;
293   for (Operation *root : state.getPayloadOps(getRoot())) {
294     if (failed(extension->findAllMatches(
295             getPatternName().getLeafReference().getValue(), root, targets))) {
296       emitOpError() << "could not find pattern '" << getPatternName() << "'";
297       return DiagnosedSilencableFailure::definiteFailure();
298     }
299   }
300   results.set(getResult().cast<OpResult>(), targets);
301   return DiagnosedSilencableFailure::success();
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // SequenceOp
306 //===----------------------------------------------------------------------===//
307 
308 DiagnosedSilencableFailure
309 transform::SequenceOp::apply(transform::TransformResults &results,
310                              transform::TransformState &state) {
311   // Map the entry block argument to the list of operations.
312   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
313   if (failed(mapBlockArguments(state)))
314     return DiagnosedSilencableFailure::definiteFailure();
315 
316   // Apply the sequenced ops one by one.
317   for (Operation &transform : getBodyBlock()->without_terminator()) {
318     DiagnosedSilencableFailure result =
319         state.applyTransform(cast<TransformOpInterface>(transform));
320     if (!result.succeeded())
321       return result;
322   }
323 
324   // Forward the operation mapping for values yielded from the sequence to the
325   // values produced by the sequence op.
326   forwardTerminatorOperands(getBodyBlock(), state, results);
327   return DiagnosedSilencableFailure::success();
328 }
329 
330 /// Returns `true` if the given op operand may be consuming the handle value in
331 /// the Transform IR. That is, if it may have a Free effect on it.
332 static bool isValueUsePotentialConsumer(OpOperand &use) {
333   // Conservatively assume the effect being present in absence of the interface.
334   auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(use.getOwner());
335   if (!memEffectInterface)
336     return true;
337 
338   SmallVector<MemoryEffects::EffectInstance, 2> effects;
339   memEffectInterface.getEffectsOnValue(use.get(), effects);
340   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
341     return isa<transform::TransformMappingResource>(effect.getResource()) &&
342            isa<MemoryEffects::Free>(effect.getEffect());
343   });
344 }
345 
346 LogicalResult
347 checkDoubleConsume(Value value,
348                    function_ref<InFlightDiagnostic()> reportError) {
349   OpOperand *potentialConsumer = nullptr;
350   for (OpOperand &use : value.getUses()) {
351     if (!isValueUsePotentialConsumer(use))
352       continue;
353 
354     if (!potentialConsumer) {
355       potentialConsumer = &use;
356       continue;
357     }
358 
359     InFlightDiagnostic diag = reportError()
360                               << " has more than one potential consumer";
361     diag.attachNote(potentialConsumer->getOwner()->getLoc())
362         << "used here as operand #" << potentialConsumer->getOperandNumber();
363     diag.attachNote(use.getOwner()->getLoc())
364         << "used here as operand #" << use.getOperandNumber();
365     return diag;
366   }
367 
368   return success();
369 }
370 
371 LogicalResult transform::SequenceOp::verify() {
372   // Check if the block argument has more than one consuming use.
373   for (BlockArgument argument : getBodyBlock()->getArguments()) {
374     auto report = [&]() {
375       return (emitOpError() << "block argument #" << argument.getArgNumber());
376     };
377     if (failed(checkDoubleConsume(argument, report)))
378       return failure();
379   }
380 
381   // Check properties of the nested operations they cannot check themselves.
382   for (Operation &child : *getBodyBlock()) {
383     if (!isa<TransformOpInterface>(child) &&
384         &child != &getBodyBlock()->back()) {
385       InFlightDiagnostic diag =
386           emitOpError()
387           << "expected children ops to implement TransformOpInterface";
388       diag.attachNote(child.getLoc()) << "op without interface";
389       return diag;
390     }
391 
392     for (OpResult result : child.getResults()) {
393       auto report = [&]() {
394         return (child.emitError() << "result #" << result.getResultNumber());
395       };
396       if (failed(checkDoubleConsume(result, report)))
397         return failure();
398     }
399   }
400 
401   if (getBodyBlock()->getTerminator()->getOperandTypes() !=
402       getOperation()->getResultTypes()) {
403     InFlightDiagnostic diag = emitOpError()
404                               << "expects the types of the terminator operands "
405                                  "to match the types of the result";
406     diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
407     return diag;
408   }
409   return success();
410 }
411 
412 void transform::SequenceOp::getEffects(
413     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
414   auto *mappingResource = TransformMappingResource::get();
415   effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
416 
417   for (Value result : getResults()) {
418     effects.emplace_back(MemoryEffects::Allocate::get(), result,
419                          mappingResource);
420     effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
421   }
422 
423   if (!getRoot()) {
424     for (Operation &op : *getBodyBlock()) {
425       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
426       if (!iface) {
427         // TODO: fill all possible effects; or require ops to actually implement
428         // the memory effect interface always
429         assert(false);
430       }
431 
432       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
433       iface.getEffects(effects);
434     }
435     return;
436   }
437 
438   // Carry over all effects on the argument of the entry block as those on the
439   // operand, this is the same value just remapped.
440   for (Operation &op : *getBodyBlock()) {
441     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
442     if (!iface) {
443       // TODO: fill all possible effects; or require ops to actually implement
444       // the memory effect interface always
445       assert(false);
446     }
447 
448     SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
449     iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
450     for (const auto &effect : nestedEffects)
451       effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
452   }
453 }
454 
455 OperandRange
456 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
457   assert(index && *index == 0 && "unexpected region index");
458   if (getOperation()->getNumOperands() == 1)
459     return getOperation()->getOperands();
460   return OperandRange(getOperation()->operand_end(),
461                       getOperation()->operand_end());
462 }
463 
464 void transform::SequenceOp::getSuccessorRegions(
465     Optional<unsigned> index, ArrayRef<Attribute> operands,
466     SmallVectorImpl<RegionSuccessor> &regions) {
467   if (!index.hasValue()) {
468     Region *bodyRegion = &getBody();
469     regions.emplace_back(bodyRegion, !operands.empty()
470                                          ? bodyRegion->getArguments()
471                                          : Block::BlockArgListType());
472     return;
473   }
474 
475   assert(*index == 0 && "unexpected region index");
476   regions.emplace_back(getOperation()->getResults());
477 }
478 
479 void transform::SequenceOp::getRegionInvocationBounds(
480     ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
481   (void)operands;
482   bounds.emplace_back(1, 1);
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // WithPDLPatternsOp
487 //===----------------------------------------------------------------------===//
488 
489 DiagnosedSilencableFailure
490 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
491                                     transform::TransformState &state) {
492   OwningOpRef<ModuleOp> pdlModuleOp =
493       ModuleOp::create(getOperation()->getLoc());
494   TransformOpInterface transformOp = nullptr;
495   for (Operation &nested : getBody().front()) {
496     if (!isa<pdl::PatternOp>(nested)) {
497       transformOp = cast<TransformOpInterface>(nested);
498       break;
499     }
500   }
501 
502   state.addExtension<PatternApplicatorExtension>(getOperation());
503   auto guard = llvm::make_scope_exit(
504       [&]() { state.removeExtension<PatternApplicatorExtension>(); });
505 
506   auto scope = state.make_region_scope(getBody());
507   if (failed(mapBlockArguments(state)))
508     return DiagnosedSilencableFailure::definiteFailure();
509   return state.applyTransform(transformOp);
510 }
511 
512 LogicalResult transform::WithPDLPatternsOp::verify() {
513   Block *body = getBodyBlock();
514   Operation *topLevelOp = nullptr;
515   for (Operation &op : body->getOperations()) {
516     if (isa<pdl::PatternOp>(op))
517       continue;
518 
519     if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
520       if (topLevelOp) {
521         InFlightDiagnostic diag =
522             emitOpError() << "expects only one non-pattern op in its body";
523         diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
524         diag.attachNote(op.getLoc()) << "second non-pattern op";
525         return diag;
526       }
527       topLevelOp = &op;
528       continue;
529     }
530 
531     InFlightDiagnostic diag =
532         emitOpError()
533         << "expects only pattern and top-level transform ops in its body";
534     diag.attachNote(op.getLoc()) << "offending op";
535     return diag;
536   }
537 
538   if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
539     InFlightDiagnostic diag = emitOpError() << "cannot be nested";
540     diag.attachNote(parent.getLoc()) << "parent operation";
541     return diag;
542   }
543 
544   return success();
545 }
546