1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
11 
12 #include "mlir/IR/OpDefinition.h"
13 
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "llvm/ADT/ScopeExit.h"
16 
17 namespace mlir {
18 
19 /// The result of a transform IR operation application. This can have one of the
20 /// three states:
21 ///   - success;
22 ///   - silenceable (recoverable) failure with yet-unreported diagnostic;
23 ///   - definite failure.
24 /// Silenceable failure is intended to communicate information about
25 /// transformations that did not apply but in a way that supports recovery,
26 /// for example, they did not modify the payload IR or modified it in some
27 /// predictable way. They are associated with a Diagnostic that provides more
28 /// details on the failure. Silenceable failure can be discarded, turning the
29 /// result into success, or "reported", emitting the diagnostic and turning the
30 /// result into definite failure.
31 /// Transform IR operations containing other operations are allowed to do either
32 /// with the results of the nested transformations, but must propagate definite
33 /// failures as their diagnostics have been already reported to the user.
34 class LLVM_NODISCARD DiagnosedSilenceableFailure {
35 public:
DiagnosedSilenceableFailure(LogicalResult result)36   explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
37   DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
38   DiagnosedSilenceableFailure &
39   operator=(const DiagnosedSilenceableFailure &) = delete;
40   DiagnosedSilenceableFailure(DiagnosedSilenceableFailure &&) = default;
41   DiagnosedSilenceableFailure &
42   operator=(DiagnosedSilenceableFailure &&) = default;
43 
44   /// Constructs a DiagnosedSilenceableFailure in the success state.
success()45   static DiagnosedSilenceableFailure success() {
46     return DiagnosedSilenceableFailure(::mlir::success());
47   }
48 
49   /// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
50   /// a diagnostic has been emitted before this.
definiteFailure()51   static DiagnosedSilenceableFailure definiteFailure() {
52     return DiagnosedSilenceableFailure(::mlir::failure());
53   }
54 
55   /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
56   /// ready to emit the given diagnostic. This is considered a failure
57   /// regardless of the diagnostic severity.
silenceableFailure(Diagnostic && diag)58   static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) {
59     return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
60   }
61   static DiagnosedSilenceableFailure
silenceableFailure(SmallVector<Diagnostic> && diag)62   silenceableFailure(SmallVector<Diagnostic> &&diag) {
63     return DiagnosedSilenceableFailure(
64         std::forward<SmallVector<Diagnostic>>(diag));
65   }
66 
67   /// Converts all kinds of failure into a LogicalResult failure, emitting the
68   /// diagnostic if necessary. Must not be called more than once.
checkAndReport()69   LogicalResult checkAndReport() {
70 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
71     assert(!reported && "attempting to report a diagnostic more than once");
72     reported = true;
73 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
74     if (!diagnostics.empty()) {
75       for (auto &&diagnostic : diagnostics) {
76         diagnostic.getLocation().getContext()->getDiagEngine().emit(
77             std::move(diagnostic));
78       }
79       diagnostics.clear();
80       result = ::mlir::failure();
81     }
82     return result;
83   }
84 
85   /// Returns `true` if this is a success.
succeeded()86   bool succeeded() const {
87     return ::mlir::succeeded(result) && diagnostics.empty();
88   }
89 
90   /// Returns `true` if this is a definite failure.
isDefiniteFailure()91   bool isDefiniteFailure() const {
92     return ::mlir::failed(result) && diagnostics.empty();
93   }
94 
95   /// Returns `true` if this is a silenceable failure.
isSilenceableFailure()96   bool isSilenceableFailure() const { return !diagnostics.empty(); }
97 
98   /// Returns the diagnostic message without emitting it. Expects this object
99   /// to be a silenceable failure.
getMessage()100   std::string getMessage() const {
101     std::string res;
102     for (auto &diagnostic : diagnostics) {
103       res.append(diagnostic.str());
104       res.append("\n");
105     }
106     return res;
107   }
108 
109   /// Returns a string representation of the failure mode (for error reporting).
getStatusString()110   std::string getStatusString() const {
111     if (succeeded())
112       return "success";
113     if (isSilenceableFailure())
114       return "silenceable failure";
115     return "definite failure";
116   }
117 
118   /// Converts silenceable failure into LogicalResult success without reporting
119   /// the diagnostic, preserves the other states.
silence()120   LogicalResult silence() {
121     if (!diagnostics.empty()) {
122       diagnostics.clear();
123       result = ::mlir::success();
124     }
125     return result;
126   }
127 
128   /// Take the diagnostic and silence.
takeDiagnostics()129   SmallVector<Diagnostic> &&takeDiagnostics() {
130     assert(!diagnostics.empty() && "expected a diagnostic to be present");
131     auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
132     return std::move(diagnostics);
133   }
134 
135   /// Streams the given values into the last diagnotic.
136   /// Expects this object to be a silenceable failure.
137   template <typename T>
138   DiagnosedSilenceableFailure &operator<<(T &&value) & {
139     assert(isSilenceableFailure() &&
140            "can only append output in silenceable failure state");
141     diagnostics.back() << std::forward<T>(value);
142     return *this;
143   }
144   template <typename T>
145   DiagnosedSilenceableFailure &&operator<<(T &&value) && {
146     return std::move(this->operator<<(std::forward<T>(value)));
147   }
148 
149   /// Attaches a note to the last diagnostic.
150   /// Expects this object to be a silenceable failure.
151   Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
152     assert(isSilenceableFailure() &&
153            "can only attach notes to silenceable failures");
154     return diagnostics.back().attachNote(loc);
155   }
156 
157 private:
DiagnosedSilenceableFailure(Diagnostic && diagnostic)158   explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
159       : diagnostics(), result(failure()) {
160     diagnostics.emplace_back(std::move(diagnostic));
161   }
DiagnosedSilenceableFailure(SmallVector<Diagnostic> && diagnostics)162   explicit DiagnosedSilenceableFailure(SmallVector<Diagnostic> &&diagnostics)
163       : diagnostics(std::move(diagnostics)), result(failure()) {}
164 
165   /// The diagnostics associated with this object. If non-empty, the object is
166   /// considered to be in the silenceable failure state regardless of the
167   /// `result` field.
168   SmallVector<Diagnostic, 1> diagnostics;
169 
170   /// The "definite" logical state, either success or failure.
171   /// Ignored if the diagnostics message is present.
172   LogicalResult result;
173 
174 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
175   /// Whether the associated diagnostics have been reported.
176   /// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
177   /// differentiate reported diagnostics from a state where it was never
178   /// created.
179   bool reported = false;
180 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
181 };
182 
183 namespace transform {
184 
185 class TransformOpInterface;
186 
/// Options controlling the application of transform operations by the
/// TransformState. Setters return a reference to the object so that calls can
/// be chained.
class TransformOptions {
public:
  /// Creates the default set of options; expensive checks start enabled (see
  /// the default member initializer below).
  TransformOptions() = default;

  /// Requests computationally expensive checks of the transform and payload IR
  /// well-formedness to be performed before each transformation. In particular,
  /// these ensure that the handles still point to valid operations when used.
  /// Passing `false` turns the checks off. Returns `*this`.
  TransformOptions &enableExpensiveChecks(bool enable = true) {
    expensiveChecksEnabled = enable;
    return *this;
  }

  /// Returns true if the expensive checks are requested.
  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }

private:
  /// Whether the expensive checks are requested.
  bool expensiveChecksEnabled = true;
};
207 
208 /// The state maintained across applications of various ops implementing the
209 /// TransformOpInterface. The operations implementing this interface and the
210 /// surrounding structure are referred to as transform IR. The operations to
211 /// which transformations apply are referred to as payload IR. The state thus
212 /// contains the mapping between values defined in the transform IR ops and
213 /// payload IR ops. It assumes that each value in the transform IR can be used
214 /// at most once (since transformations are likely to change the payload IR ops
215 /// the value corresponds to). Checks that transform IR values correspond to
216 /// disjoint sets of payload IR ops throughout the transformation.
217 ///
218 /// A reference to this class is passed as an argument to "apply" methods of the
219 /// transform op interface. Thus the "apply" method can call
220 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
221 /// associated with its operand and subject to transformation. The method is
222 /// expected to populate the `TransformResults` class instance in order to
223 /// update the mapping. The `applyTransform` method takes care of propagating
224 /// the state of `TransformResults` into the instance of this class.
225 ///
226 /// When applying transform IR operations with regions, the client is expected
227 /// to create a RegionScope RAII object to create a new "stack frame" for
228 /// values defined inside the region. The mappings from and to these values will
229 /// be automatically dropped when the object goes out of scope, typically at the
230 /// end of the "apply" function of the parent operation. If a region contains
231 /// blocks with arguments, the client can map those arguments to payload IR ops
232 /// using "mapBlockArguments".
233 class TransformState {
234   /// Mapping between a Value in the transform IR and the corresponding set of
235   /// operations in the payload IR.
236   using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
237 
238   /// Mapping between a payload IR operation and the transform IR value it is
239   /// currently associated with.
240   using TransformOpReverseMapping = DenseMap<Operation *, Value>;
241 
242   /// Bidirectional mappings between transform IR values and payload IR
243   /// operations.
244   struct Mappings {
245     TransformOpMapping direct;
246     TransformOpReverseMapping reverse;
247   };
248 
249 public:
250   /// Creates a state for transform ops living in the given region. The parent
251   /// operation of the region. The second argument points to the root operation
252   /// in the payload IR beind transformed, which may or may not contain the
253   /// region with transform ops. Additional options can be provided through the
254   /// trailing configuration object.
255   TransformState(Region &region, Operation *root,
256                  const TransformOptions &options = TransformOptions());
257 
258   /// Returns the op at which the transformation state is rooted. This is
259   /// typically helpful for transformations that apply globally.
260   Operation *getTopLevel() const;
261 
262   /// Returns the list of ops that the given transform IR value corresponds to.
263   /// This is helpful for transformations that apply to a particular handle.
264   ArrayRef<Operation *> getPayloadOps(Value value) const;
265 
266   /// Returns the Transform IR handle for the given Payload IR op if it exists
267   /// in the state, null otherwise.
268   Value getHandleForPayloadOp(Operation *op) const;
269 
270   /// Applies the transformation specified by the given transform op and updates
271   /// the state accordingly.
272   DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
273 
274   /// Records the mapping between a block argument in the transform IR and a
275   /// list of operations in the payload IR. The arguments must be defined in
276   /// blocks of the currently processed transform IR region, typically after a
277   /// region scope is defined.
mapBlockArguments(BlockArgument argument,ArrayRef<Operation * > operations)278   LogicalResult mapBlockArguments(BlockArgument argument,
279                                   ArrayRef<Operation *> operations) {
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281     assert(argument.getParentRegion() == regionStack.back() &&
282            "mapping block arguments from a region other than the active one");
283 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
284     return setPayloadOps(argument, operations);
285   }
286 
287   // Forward declarations to support limited visibility.
288   class RegionScope;
289 
290   /// Creates a new region scope for the given region. The region is expected to
291   /// be nested in the currently processed region.
292   // Implementation note: this method is inline but implemented outside of the
293   // class body to comply with visibility and full-declaration requirements.
294   inline RegionScope make_region_scope(Region &region);
295 
296   /// A RAII object maintaining a "stack frame" for a transform IR region. When
297   /// applying a transform IR operation that contains a region, the caller is
298   /// expected to create a RegionScope before applying the ops contained in the
299   /// region. This ensures that the mappings between values defined in the
300   /// transform IR region and payload IR operations are cleared when the region
301   /// processing ends; such values cannot be accessed outside the region.
302   class RegionScope {
303   public:
304     /// Forgets the mapping from or to values defined in the associated
305     /// transform IR region.
~RegionScope()306     ~RegionScope() {
307       state.mappings.erase(region);
308 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
309       state.regionStack.pop_back();
310 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
311     }
312 
313   private:
314     /// Creates a new scope for mappings between values defined in the given
315     /// transform IR region and payload IR operations.
RegionScope(TransformState & state,Region & region)316     RegionScope(TransformState &state, Region &region)
317         : state(state), region(&region) {
318       auto res = state.mappings.try_emplace(this->region);
319       assert(res.second && "the region scope is already present");
320       (void)res;
321 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
322       assert(state.regionStack.back()->isProperAncestor(&region) &&
323              "scope started at a non-nested region");
324       state.regionStack.push_back(&region);
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
326     }
327 
328     /// Back-reference to the transform state.
329     TransformState &state;
330 
331     /// The region this scope is associated with.
332     Region *region;
333 
334     friend RegionScope TransformState::make_region_scope(Region &);
335   };
336   friend class RegionScope;
337 
338   /// Base class for TransformState extensions that allow TransformState to
339   /// contain user-specified information in the state object. Clients are
340   /// expected to derive this class, add the desired fields, and make the
341   /// derived class compatible with the MLIR TypeID mechanism:
342   ///
343   /// ```mlir
344   /// class MyExtension final : public TransformState::Extension {
345   /// public:
346   ///   MyExtension(TranfsormState &state, int myData)
347   ///     : Extension(state) {...}
348   /// private:
349   ///   int mySupplementaryData;
350   /// };
351   /// ```
352   ///
353   /// Instances of this and derived classes are not expected to be created by
354   /// the user, instead they are directly constructed within a TransformState. A
355   /// TransformState can only contain one extension with the given TypeID.
356   /// Extensions can be obtained from a TransformState instance, and can be
357   /// removed when they are no longer required.
358   ///
359   /// ```mlir
360   /// transformState.addExtension<MyExtension>(/*myData=*/42);
361   /// MyExtension *ext = transformState.getExtension<MyExtension>();
362   /// ext->doSomething();
363   /// ```
364   class Extension {
365     // Allow TransformState to allocate Extensions.
366     friend class TransformState;
367 
368   public:
369     /// Base virtual destructor.
370     // Out-of-line definition ensures symbols are emitted in a single object
371     // file.
372     virtual ~Extension();
373 
374   protected:
375     /// Constructs an extension of the given TransformState object.
Extension(TransformState & state)376     Extension(TransformState &state) : state(state) {}
377 
378     /// Provides read-only access to the parent TransformState object.
getTransformState()379     const TransformState &getTransformState() const { return state; }
380 
381     /// Replaces the given payload op with another op. If the replacement op is
382     /// null, removes the association of the payload op with its handle.
383     LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
384 
385   private:
386     /// Back-reference to the state that is being extended.
387     TransformState &state;
388   };
389 
390   /// Adds a new Extension of the type specified as template parameter,
391   /// constructing it with the arguments provided. The extension is owned by the
392   /// TransformState. It is expected that the state does not already have an
393   /// extension of the same type. Extension constructors are expected to take
394   /// a reference to TransformState as first argument, automatically supplied
395   /// by this call.
396   template <typename Ty, typename... Args>
addExtension(Args &&...args)397   Ty &addExtension(Args &&...args) {
398     static_assert(
399         std::is_base_of<Extension, Ty>::value,
400         "only an class derived from TransformState::Extension is allowed here");
401     auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
402     auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
403     assert(result.second && "extension already added");
404     return *static_cast<Ty *>(result.first->second.get());
405   }
406 
407   /// Returns the extension of the specified type.
408   template <typename Ty>
getExtension()409   Ty *getExtension() {
410     static_assert(
411         std::is_base_of<Extension, Ty>::value,
412         "only an class derived from TransformState::Extension is allowed here");
413     auto iter = extensions.find(TypeID::get<Ty>());
414     if (iter == extensions.end())
415       return nullptr;
416     return static_cast<Ty *>(iter->second.get());
417   }
418 
419   /// Removes the extension of the specified type.
420   template <typename Ty>
removeExtension()421   void removeExtension() {
422     static_assert(
423         std::is_base_of<Extension, Ty>::value,
424         "only an class derived from TransformState::Extension is allowed here");
425     extensions.erase(TypeID::get<Ty>());
426   }
427 
428 private:
429   /// Identifier for storing top-level value in the `operations` mapping.
430   static constexpr Value kTopLevelValue = Value();
431 
432   /// Returns the mappings frame for the reigon in which the value is defined.
getMapping(Value value)433   const Mappings &getMapping(Value value) const {
434     return const_cast<TransformState *>(this)->getMapping(value);
435   }
getMapping(Value value)436   Mappings &getMapping(Value value) {
437     auto it = mappings.find(value.getParentRegion());
438     assert(it != mappings.end() &&
439            "trying to find a mapping for a value from an unmapped region");
440     return it->second;
441   }
442 
443   /// Returns the mappings frame for the region in which the operation resides.
getMapping(Operation * operation)444   const Mappings &getMapping(Operation *operation) const {
445     return const_cast<TransformState *>(this)->getMapping(operation);
446   }
getMapping(Operation * operation)447   Mappings &getMapping(Operation *operation) {
448     auto it = mappings.find(operation->getParentRegion());
449     assert(it != mappings.end() &&
450            "trying to find a mapping for an operation from an unmapped region");
451     return it->second;
452   }
453 
454   /// Sets the payload IR ops associated with the given transform IR value.
455   /// Fails if this would result in multiple transform IR values with uses
456   /// corresponding to the same payload IR ops. For example, a hypothetical
457   /// "find function by name" transform op would (indirectly) call this
458   /// function for its result. Having two such calls in a row with for different
459   /// values, e.g. coming from different ops:
460   ///
461   ///   %0 = transform.find_func_by_name { name = "myfunc" }
462   ///   %1 = transform.find_func_by_name { name = "myfunc" }
463   ///
464   /// would lead to both values pointing to the same operation. The second call
465   /// to setPayloadOps will fail, unless the association with the %0 value is
466   /// removed first by calling update/removePayloadOps.
467   LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
468 
469   /// Forgets the payload IR ops associated with the given transform IR value.
470   void removePayloadOps(Value value);
471 
472   /// Updates the payload IR ops associated with the given transform IR value.
473   /// The callback function is called once per associated operation and is
474   /// expected to return the modified operation or nullptr. In the latter case,
475   /// the corresponding operation is no longer associated with the transform IR
476   /// value. May fail if the operation produced by the update callback is
477   /// already associated with a different Transform IR handle value.
478   LogicalResult
479   updatePayloadOps(Value value,
480                    function_ref<Operation *(Operation *)> callback);
481 
482   /// Attempts to record the mapping between the given Payload IR operation and
483   /// the given Transform IR handle. Fails and reports an error if the operation
484   /// is already tracked by another handle.
485   static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
486                                                 Value handle);
487 
488   /// If the operand is a handle consumed by the operation, i.e. has the "free"
489   /// memory effect associated with it, identifies other handles that are
490   /// pointing to payload IR operations nested in the operations pointed to by
491   /// the consumed handle. Marks all such handles as invalidated so trigger
492   /// errors if they are used.
493   void recordHandleInvalidation(OpOperand &handle);
494 
495   /// Checks that the operation does not use invalidated handles as operands.
496   /// Reports errors and returns failure if it does. Otherwise, invalidates the
497   /// handles consumed by the operation as well as any handles pointing to
498   /// payload IR operations nested in the operations associated with the
499   /// consumed handles.
500   LogicalResult
501   checkAndRecordHandleInvalidation(TransformOpInterface transform);
502 
503   /// The mappings between transform IR values and payload IR ops, aggregated by
504   /// the region in which the transform IR values are defined.
505   llvm::SmallDenseMap<Region *, Mappings> mappings;
506 
507   /// Extensions attached to the TransformState, identified by the TypeID of
508   /// their type. Only one extension of any given type is allowed.
509   DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
510 
511   /// The top-level operation that contains all payload IR, typically a module.
512   Operation *topLevel;
513 
514   /// Additional options controlling the transformation state behavior.
515   TransformOptions options;
516 
517   /// The mapping from invalidated handles to the error-reporting functions that
518   /// describe when the handles were invalidated. Calling such a function emits
519   /// a user-visible diagnostic.
520   DenseMap<Value, std::function<void()>> invalidatedHandles;
521 
522 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
523   /// A stack of nested regions that are being processed in the transform IR.
524   /// Each region must be an ancestor of the following regions in this list.
525   /// These are also the keys for "mappings".
526   SmallVector<Region *> regionStack;
527 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
528 };
529 
530 /// Local mapping between values defined by a specific op implementing the
531 /// TransformOpInterface and the payload IR ops they correspond to.
532 class TransformResults {
533   friend class TransformState;
534 
535 public:
536   /// Indicates that the result of the transform IR op at the given position
537   /// corresponds to the given list of payload IR ops. Each result must be set
538   /// by the transformation exactly once.
539   void set(OpResult value, ArrayRef<Operation *> ops);
540 
541 private:
542   /// Creates an instance of TransformResults that expects mappings for
543   /// `numSegments` values.
544   explicit TransformResults(unsigned numSegments);
545 
546   /// Gets the list of operations associated with the result identified by its
547   /// number in the list of operation results.
548   ArrayRef<Operation *> get(unsigned resultNumber) const;
549 
550   /// Storage for pointers to payload IR ops that are associated with results of
551   /// a transform IR op. `segments` contains as many entries as the transform IR
552   /// op has results. Each entry is a reference to a contiguous segment in
553   /// the `operations` list that contains the pointers to operations. This
554   /// allows for operations to be stored contiguously without nested vectors and
555   /// for different segments to be set in any order.
556   SmallVector<ArrayRef<Operation *>, 2> segments;
557   SmallVector<Operation *> operations;
558 };
559 
make_region_scope(Region & region)560 TransformState::RegionScope TransformState::make_region_scope(Region &region) {
561   return RegionScope(*this, region);
562 }
563 
564 namespace detail {
565 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
566 /// to either the list of operations associated with its operand or the root of
567 /// the payload IR, depending on what is available in the context.
568 LogicalResult
569 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
570                                              Operation *op, Region &region);
571 
572 /// Verification hook for PossibleTopLevelTransformOpTrait.
573 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
574 } // namespace detail
575 
576 /// This trait is supposed to be attached to Transform dialect operations that
577 /// can be standalone top-level transforms. Such operations typically contain
578 /// other Transform dialect operations that can be executed following some
579 /// control flow logic specific to the current operation. The operations with
580 /// this trait are expected to have at least one single-block region with one
581 /// argument of PDL Operation type. The operations are also expected to be valid
582 /// without operands, in which case they are considered top-level, and with one
583 /// or more arguments, in which case they are considered nested. Top-level
584 /// operations have the block argument of the entry block in the Transform IR
585 /// correspond to the root operation of Payload IR. Nested operations have the
586 /// block argument of the entry block in the Transform IR correspond to a list
587 /// of Payload IR operations mapped to the first operand of the Transform IR
588 /// operation. The operation must implement TransformOpInterface.
589 template <typename OpTy>
590 class PossibleTopLevelTransformOpTrait
591     : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
592 public:
593   /// Verifies that `op` satisfies the invariants of this trait. Not expected to
594   /// be called directly.
verifyTrait(Operation * op)595   static LogicalResult verifyTrait(Operation *op) {
596     return detail::verifyPossibleTopLevelTransformOpTrait(op);
597   }
598 
599   /// Returns the single block of the given region.
600   Block *getBodyBlock(unsigned region = 0) {
601     return &this->getOperation()->getRegion(region).front();
602   }
603 
604   /// Sets up the mapping between the entry block of the given region of this op
605   /// and the relevant list of Payload IR operations in the given state. The
606   /// state is expected to be already scoped at the region of this operation.
607   /// Returns failure if the mapping failed, e.g., the value is already mapped.
mapBlockArguments(TransformState & state,Region & region)608   LogicalResult mapBlockArguments(TransformState &state, Region &region) {
609     assert(region.getParentOp() == this->getOperation() &&
610            "op comes from the wrong region");
611     return detail::mapPossibleTopLevelTransformOpBlockArguments(
612         state, this->getOperation(), region);
613   }
mapBlockArguments(TransformState & state)614   LogicalResult mapBlockArguments(TransformState &state) {
615     assert(
616         this->getOperation()->getNumRegions() == 1 &&
617         "must indicate the region to map if the operation has more than one");
618     return mapBlockArguments(state, this->getOperation()->getRegion(0));
619   }
620 };
621 
/// Trait implementing the TransformOpInterface for operations applying a
/// transformation to a single operation handle and producing zero, one or
/// multiple operation handles.
/// The op must implement a method with the following signature:
///   - DiagnosedSilenceableFailure applyToOne(OpTy,
///       SmallVector<Operation*> &results, state)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
/// that the transformation is applied to (and NOT the class of the transform IR
/// op).
/// The `applyToOne` method takes an empty `results` vector that it fills with
/// zero, one or multiple operations depending on the number of results expected
/// by the transform op.
/// The number of results must match the number of results of the transform op.
/// `applyToOne` is allowed to fill the `results` with all null elements to
/// signify that the transformation did not apply to the payload IR operations.
/// Such null elements are filtered out from results before return.
///
/// The transform op having this trait is expected to have a single operand.
template <typename OpTy>
class TransformEachOpTrait
    : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
public:
  /// Calls `applyToOne` for every payload operation associated with the operand
  /// of this transform IR op, the following case disjunction happens:
  ///   1. If no target payload ops are associated with the operand, then
  ///      associate every result of this transform op with an empty list of
  ///      payload ops and return success. This is the corner case handling
  ///      that allows propagating the "no-op" case gracefully to improve
  ///      usability.
  ///   2. If any `applyToOne` returns definiteFailure, the transformation is
  ///      immediately considered definitely failed and we return.
  ///   3. All applications of `applyToOne` are checked to return a number of
  ///      results expected by the transform IR op. If not, this is a definite
  ///      failure and we return early.
  ///   4. If `applyToOne` produces ops, associate them with the result of this
  ///      transform op.
  ///   5. If any `applyToOne` returns silenceableFailure, the transformation is
  ///      considered silenceable.
  ///   6. Otherwise the transformation is considered successful.
  DiagnosedSilenceableFailure apply(TransformResults &transformResults,
                                    TransformState &state);

  /// Checks that the op matches the expectations of this trait: it must have a
  /// single operand (checked at compile time) and implement
  /// TransformOpInterface.
  static LogicalResult verifyTrait(Operation *op);
};
668 
/// Side effect resource corresponding to the mapping between Transform IR
/// values and Payload IR operations. An Allocate effect from this resource
/// means creating a new mapping entry, it is always accompanied by a Write
/// effect. A Read effect from this resource means accessing the mapping. A Free
/// effect on this resource indicates the removal of the mapping entry,
/// typically after a transformation that modifies the Payload IR operations
/// associated with one of the Transform IR operation's operands. It is always
/// accompanied by a Read effect. Read-after-Free and double-Free are not
/// allowed (they would be problematic with "regular" memory effects too) as
/// they indicate an attempt to access Payload IR operations that have been
/// modified, potentially erased, by the previous transformations.
// TODO: consider custom effects if these are not enabling generic passes such
// as CSE/DCE to work.
struct TransformMappingResource
    : public SideEffects::Resource::Base<TransformMappingResource> {
  /// Returns the name identifying this resource.
  StringRef getName() override { return "transform.mapping"; }
};
686 
/// Side effect resource corresponding to the Payload IR itself. Only Read and
/// Write effects are expected on this resource, with Write always accompanied
/// by a Read (short of fully replacing the top-level Payload IR operation, one
/// cannot modify the Payload IR without reading it first). This is intended
/// to disallow reordering of Transform IR operations that mutate the Payload IR
/// while still allowing the reordering of those that only access it.
struct PayloadIRResource
    : public SideEffects::Resource::Base<PayloadIRResource> {
  /// Returns the name identifying this resource.
  StringRef getName() override { return "transform.payload_ir"; }
};
697 
/// Populates `effects` with the memory effects indicating the operation on the
/// given handle values:
///   - consumes = Read + Free,
///   - produces = Allocate + Write,
///   - onlyReads = Read.
/// NOTE(review): the effects are presumably attached to
/// TransformMappingResource, once for each value in `handles`; confirm against
/// the definitions, which are not part of this header section.
void consumesHandle(ValueRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(ValueRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsHandle(ValueRange handles,
                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op consumes the given handle.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);

/// Populates `effects` with the memory effects indicating the access to payload
/// IR resource.
/// NOTE(review): presumably `modifiesPayload` adds Read + Write and
/// `onlyReadsPayload` adds Read only, both on PayloadIRResource; confirm
/// against the definitions.
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
717 
718 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
719 /// their operands and produce new results.
720 template <typename OpTy>
721 class FunctionalStyleTransformOpTrait
722     : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
723 public:
724   /// This op "consumes" the operands by reading and freeing then, "produces"
725   /// the results by allocating and writing it and reads/writes the payload IR
726   /// in the process.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)727   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
728     consumesHandle(this->getOperation()->getOperands(), effects);
729     producesHandle(this->getOperation()->getResults(), effects);
730     modifiesPayload(effects);
731   }
732 
733   /// Checks that the op matches the expectations of this trait.
verifyTrait(Operation * op)734   static LogicalResult verifyTrait(Operation *op) {
735     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
736       op->emitError()
737           << "FunctionalStyleTransformOpTrait should only be attached to ops "
738              "that implement MemoryEffectOpInterface";
739     }
740     return success();
741   }
742 };
743 
744 /// Trait implementing the MemoryEffectOpInterface for single-operand
745 /// single-result operations that use their operand without consuming and
746 /// without modifying the Payload IR to produce a new handle.
747 template <typename OpTy>
748 class NavigationTransformOpTrait
749     : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
750 public:
751   /// This op produces handles to the Payload IR without consuming the original
752   /// handles and without modifying the IR itself.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)753   void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
754     onlyReadsHandle(this->getOperation()->getOperands(), effects);
755     producesHandle(this->getOperation()->getResults(), effects);
756     onlyReadsPayload(effects);
757   }
758 
759   /// Checks that the op matches the expectation of this trait.
verifyTrait(Operation * op)760   static LogicalResult verifyTrait(Operation *op) {
761     static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
762                   "expected single-operand op");
763     static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
764                   "expected single-result op");
765     if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
766       op->emitError() << "NavigationTransformOpTrait should only be attached "
767                          "to ops that implement MemoryEffectOpInterface";
768     }
769     return success();
770   }
771 };
772 
773 } // namespace transform
774 } // namespace mlir
775 
776 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
777 
778 namespace mlir {
779 namespace transform {
780 namespace detail {
781 /// Applies a one-to-one or a one-to-many transform to each of the given
782 /// targets. Puts the results of transforms, if any, in `results` in the same
783 /// order. Fails if any of the application fails. Individual transforms must be
784 /// callable with the following signature:
785 ///   - DiagnosedSilenceableFailure(OpTy,
786 ///       SmallVector<Operation*> &results, state)
787 /// where OpTy is either
788 ///   - Operation *, in which case the transform is always applied;
789 ///   - a concrete Op class, in which case a check is performed whether
790 ///   `targets` contains operations of the same class and a silenceable failure
791 ///   is reported if it does not.
792 template <typename FnTy>
applyTransformToEach(Location loc,int expectedNumResults,ArrayRef<Operation * > targets,SmallVectorImpl<SmallVector<Operation * >> & results,FnTy transform)793 DiagnosedSilenceableFailure applyTransformToEach(
794     Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
795     SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
796   SmallVector<Diagnostic> silenceableStack;
797   using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
798   static_assert(std::is_convertible<OpTy, Operation *>::value,
799                 "expected transform function to take an operation");
800   for (Operation *target : targets) {
801     // Emplace back a placeholder for the returned new ops.
802     // This is filled with `expectedNumResults` if the op fails to apply.
803     results.push_back(SmallVector<Operation *>());
804 
805     auto specificOp = dyn_cast<OpTy>(target);
806     if (!specificOp) {
807       Diagnostic diag(loc, DiagnosticSeverity::Error);
808       diag << "transform applied to the wrong op kind";
809       diag.attachNote(target->getLoc()) << "when applied to this op";
810       // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
811       // TODO: encode this implicit `expectedNumResults` nullptr ==
812       // silenceableFailure with a proper trait.
813       results.back().assign(expectedNumResults, nullptr);
814       silenceableStack.push_back(std::move(diag));
815       continue;
816     }
817 
818     DiagnosedSilenceableFailure result = transform(specificOp, results.back());
819     if (result.isDefiniteFailure())
820       return result;
821     if (result.isSilenceableFailure())
822       for (auto &&diag : result.takeDiagnostics())
823         silenceableStack.push_back(std::move(diag));
824   }
825   if (!silenceableStack.empty()) {
826     return DiagnosedSilenceableFailure::silenceableFailure(
827         std::move(silenceableStack));
828   }
829   return DiagnosedSilenceableFailure::success();
830 }
831 
832 /// Helper function: transpose MxN into NxM; assumes that the input is a valid.
833 static inline SmallVector<SmallVector<Operation *, 1>>
transposeResults(const SmallVector<SmallVector<Operation * >,1> & m)834 transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
835   SmallVector<SmallVector<Operation *, 1>> res;
836   if (m.empty())
837     return res;
838   int64_t rows = m.size(), cols = m[0].size();
839   for (int64_t j = 0; j < cols; ++j)
840     res.push_back(SmallVector<Operation *, 1>(rows, nullptr));
841   for (int64_t i = 0; i < rows; ++i) {
842     assert(static_cast<int64_t>(m[i].size()) == cols);
843     for (int64_t j = 0; j < cols; ++j) {
844       res[j][i] = m[i][j];
845     }
846   }
847   return res;
848 }
849 } // namespace detail
850 } // namespace transform
851 } // namespace mlir
852 
853 template <typename OpTy>
854 mlir::DiagnosedSilenceableFailure
apply(TransformResults & transformResults,TransformState & state)855 mlir::transform::TransformEachOpTrait<OpTy>::apply(
856     TransformResults &transformResults, TransformState &state) {
857   using TransformOpType = typename llvm::function_traits<
858       decltype(&OpTy::applyToOne)>::template arg_t<0>;
859   ArrayRef<Operation *> targets =
860       state.getPayloadOps(this->getOperation()->getOperand(0));
861 
862   // Step 1. Handle the corner case where no target is specified.
863   // This is typically the case when the matcher fails to apply and we need to
864   // propagate gracefully.
865   // In this case, we fill all results with an empty vector.
866   if (targets.empty()) {
867     SmallVector<Operation *> empty;
868     for (auto r : this->getOperation()->getResults())
869       transformResults.set(r.template cast<OpResult>(), empty);
870     return DiagnosedSilenceableFailure::success();
871   }
872 
873   // Step 2. Call applyToOne on each target and record newly produced ops in its
874   // corresponding results entry.
875   int expectedNumResults = this->getOperation()->getNumResults();
876   SmallVector<SmallVector<Operation *>, 1> results;
877   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
878       this->getOperation()->getLoc(), expectedNumResults, targets, results,
879       [&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
880         auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
881                                                          partialResult, state);
882         if (res.isDefiniteFailure())
883           return res;
884 
885         // TODO: encode this implicit must always produce `expectedNumResults`
886         // and nullptr is fine with a proper trait.
887         if (static_cast<int>(partialResult.size()) != expectedNumResults) {
888           auto loc = this->getOperation()->getLoc();
889           auto diag = mlir::emitError(loc, "applications of ")
890                       << OpTy::getOperationName() << " expected to produce "
891                       << expectedNumResults << " results (actually produced "
892                       << partialResult.size() << ").";
893           diag.attachNote(loc)
894               << "If you need variadic results, consider a generic `apply` "
895               << "instead of the specialized `applyToOne`.";
896           diag.attachNote(loc)
897               << "Producing " << expectedNumResults << " null results is "
898               << "allowed if the use case warrants it.";
899           diag.attachNote(specificOp->getLoc()) << "when applied to this op";
900           return DiagnosedSilenceableFailure::definiteFailure();
901         }
902         // Check that all is null or none is null
903         // TODO: relax this behavior and encode with a proper trait.
904         if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
905             llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
906           auto loc = this->getOperation()->getLoc();
907           auto diag = mlir::emitError(loc, "unexpected application of ")
908                       << OpTy::getOperationName()
909                       << " produces both null and non null results.";
910           diag.attachNote(specificOp->getLoc()) << "when applied to this op";
911           return DiagnosedSilenceableFailure::definiteFailure();
912         }
913         return res;
914       });
915 
916   // Step 3. Propagate the definite failure if any and bail out.
917   if (result.isDefiniteFailure())
918     return result;
919 
920   // Step 4. If there are no results, return early.
921   if (OpTy::template hasTrait<OpTrait::ZeroResults>())
922     return result;
923 
924   // Step 5. Perform transposition of M applications producing N results each
925   // into N results for each of the M applications.
926   SmallVector<SmallVector<Operation *, 1>> transposedResults =
927       detail::transposeResults(results);
928 
929   // Step 6. Single result applies to M ops produces one single M-result.
930   if (OpTy::template hasTrait<OpTrait::OneResult>()) {
931     assert(transposedResults.size() == 1 && "Expected single result");
932     transformResults.set(
933         this->getOperation()->getResult(0).template cast<OpResult>(),
934         transposedResults[0]);
935     // ApplyToOne may have returned silenceableFailure, propagate it.
936     return result;
937   }
938 
939   // Step 7. Filter out empty results and set the transformResults.
940   for (const auto &it :
941        llvm::zip(this->getOperation()->getResults(), transposedResults)) {
942     SmallVector<Operation *, 1> filtered;
943     llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
944                   [](Operation *op) { return op; });
945     transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
946   }
947 
948   // Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
949   return result;
950 }
951 
952 template <typename OpTy>
953 mlir::LogicalResult
verifyTrait(Operation * op)954 mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
955   static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
956                 "expected single-operand op");
957   if (!op->getName().getInterface<TransformOpInterface>()) {
958     return op->emitError() << "TransformEachOpTrait should only be attached to "
959                               "ops that implement TransformOpInterface";
960   }
961 
962   return success();
963 }
964 
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
966