1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
11
12 #include "mlir/IR/OpDefinition.h"
13
14 #include "mlir/Interfaces/SideEffectInterfaces.h"
15 #include "llvm/ADT/ScopeExit.h"
16
17 namespace mlir {
18
19 /// The result of a transform IR operation application. This can have one of the
20 /// three states:
21 /// - success;
22 /// - silenceable (recoverable) failure with yet-unreported diagnostic;
23 /// - definite failure.
24 /// Silenceable failure is intended to communicate information about
25 /// transformations that did not apply but in a way that supports recovery,
26 /// for example, they did not modify the payload IR or modified it in some
27 /// predictable way. They are associated with a Diagnostic that provides more
28 /// details on the failure. Silenceable failure can be discarded, turning the
29 /// result into success, or "reported", emitting the diagnostic and turning the
30 /// result into definite failure.
31 /// Transform IR operations containing other operations are allowed to do either
32 /// with the results of the nested transformations, but must propagate definite
33 /// failures as their diagnostics have been already reported to the user.
34 class LLVM_NODISCARD DiagnosedSilenceableFailure {
35 public:
DiagnosedSilenceableFailure(LogicalResult result)36 explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
37 DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete;
38 DiagnosedSilenceableFailure &
39 operator=(const DiagnosedSilenceableFailure &) = delete;
40 DiagnosedSilenceableFailure(DiagnosedSilenceableFailure &&) = default;
41 DiagnosedSilenceableFailure &
42 operator=(DiagnosedSilenceableFailure &&) = default;
43
44 /// Constructs a DiagnosedSilenceableFailure in the success state.
success()45 static DiagnosedSilenceableFailure success() {
46 return DiagnosedSilenceableFailure(::mlir::success());
47 }
48
49 /// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
50 /// a diagnostic has been emitted before this.
definiteFailure()51 static DiagnosedSilenceableFailure definiteFailure() {
52 return DiagnosedSilenceableFailure(::mlir::failure());
53 }
54
55 /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
56 /// ready to emit the given diagnostic. This is considered a failure
57 /// regardless of the diagnostic severity.
silenceableFailure(Diagnostic && diag)58 static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) {
59 return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
60 }
61 static DiagnosedSilenceableFailure
silenceableFailure(SmallVector<Diagnostic> && diag)62 silenceableFailure(SmallVector<Diagnostic> &&diag) {
63 return DiagnosedSilenceableFailure(
64 std::forward<SmallVector<Diagnostic>>(diag));
65 }
66
67 /// Converts all kinds of failure into a LogicalResult failure, emitting the
68 /// diagnostic if necessary. Must not be called more than once.
checkAndReport()69 LogicalResult checkAndReport() {
70 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
71 assert(!reported && "attempting to report a diagnostic more than once");
72 reported = true;
73 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
74 if (!diagnostics.empty()) {
75 for (auto &&diagnostic : diagnostics) {
76 diagnostic.getLocation().getContext()->getDiagEngine().emit(
77 std::move(diagnostic));
78 }
79 diagnostics.clear();
80 result = ::mlir::failure();
81 }
82 return result;
83 }
84
85 /// Returns `true` if this is a success.
succeeded()86 bool succeeded() const {
87 return ::mlir::succeeded(result) && diagnostics.empty();
88 }
89
90 /// Returns `true` if this is a definite failure.
isDefiniteFailure()91 bool isDefiniteFailure() const {
92 return ::mlir::failed(result) && diagnostics.empty();
93 }
94
95 /// Returns `true` if this is a silenceable failure.
isSilenceableFailure()96 bool isSilenceableFailure() const { return !diagnostics.empty(); }
97
98 /// Returns the diagnostic message without emitting it. Expects this object
99 /// to be a silenceable failure.
getMessage()100 std::string getMessage() const {
101 std::string res;
102 for (auto &diagnostic : diagnostics) {
103 res.append(diagnostic.str());
104 res.append("\n");
105 }
106 return res;
107 }
108
109 /// Returns a string representation of the failure mode (for error reporting).
getStatusString()110 std::string getStatusString() const {
111 if (succeeded())
112 return "success";
113 if (isSilenceableFailure())
114 return "silenceable failure";
115 return "definite failure";
116 }
117
118 /// Converts silenceable failure into LogicalResult success without reporting
119 /// the diagnostic, preserves the other states.
silence()120 LogicalResult silence() {
121 if (!diagnostics.empty()) {
122 diagnostics.clear();
123 result = ::mlir::success();
124 }
125 return result;
126 }
127
128 /// Take the diagnostic and silence.
takeDiagnostics()129 SmallVector<Diagnostic> &&takeDiagnostics() {
130 assert(!diagnostics.empty() && "expected a diagnostic to be present");
131 auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
132 return std::move(diagnostics);
133 }
134
135 /// Streams the given values into the last diagnotic.
136 /// Expects this object to be a silenceable failure.
137 template <typename T>
138 DiagnosedSilenceableFailure &operator<<(T &&value) & {
139 assert(isSilenceableFailure() &&
140 "can only append output in silenceable failure state");
141 diagnostics.back() << std::forward<T>(value);
142 return *this;
143 }
144 template <typename T>
145 DiagnosedSilenceableFailure &&operator<<(T &&value) && {
146 return std::move(this->operator<<(std::forward<T>(value)));
147 }
148
149 /// Attaches a note to the last diagnostic.
150 /// Expects this object to be a silenceable failure.
151 Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
152 assert(isSilenceableFailure() &&
153 "can only attach notes to silenceable failures");
154 return diagnostics.back().attachNote(loc);
155 }
156
157 private:
DiagnosedSilenceableFailure(Diagnostic && diagnostic)158 explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
159 : diagnostics(), result(failure()) {
160 diagnostics.emplace_back(std::move(diagnostic));
161 }
DiagnosedSilenceableFailure(SmallVector<Diagnostic> && diagnostics)162 explicit DiagnosedSilenceableFailure(SmallVector<Diagnostic> &&diagnostics)
163 : diagnostics(std::move(diagnostics)), result(failure()) {}
164
165 /// The diagnostics associated with this object. If non-empty, the object is
166 /// considered to be in the silenceable failure state regardless of the
167 /// `result` field.
168 SmallVector<Diagnostic, 1> diagnostics;
169
170 /// The "definite" logical state, either success or failure.
171 /// Ignored if the diagnostics message is present.
172 LogicalResult result;
173
174 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
175 /// Whether the associated diagnostics have been reported.
176 /// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
177 /// differentiate reported diagnostics from a state where it was never
178 /// created.
179 bool reported = false;
180 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
181 };
182
183 namespace transform {
184
185 class TransformOpInterface;
186
/// Configuration knobs that control how a TransformState applies transform
/// operations.
class TransformOptions {
public:
  TransformOptions() = default;

  /// Turns the computationally expensive well-formedness checks of the
  /// transform and payload IR on (the default argument) or off. When on, the
  /// checks are performed before each transformation and, in particular,
  /// ensure that handles still point to valid operations when used.
  /// Returns `*this` so that calls can be chained.
  TransformOptions &enableExpensiveChecks(bool enable = true) {
    this->expensiveChecksEnabled = enable;
    return *this;
  }

  /// Reports whether the expensive checks have been requested.
  bool getExpensiveChecksEnabled() const {
    return this->expensiveChecksEnabled;
  }

private:
  /// Expensive checks stay on unless explicitly disabled.
  bool expensiveChecksEnabled{true};
};
207
208 /// The state maintained across applications of various ops implementing the
209 /// TransformOpInterface. The operations implementing this interface and the
210 /// surrounding structure are referred to as transform IR. The operations to
211 /// which transformations apply are referred to as payload IR. The state thus
212 /// contains the mapping between values defined in the transform IR ops and
213 /// payload IR ops. It assumes that each value in the transform IR can be used
214 /// at most once (since transformations are likely to change the payload IR ops
215 /// the value corresponds to). Checks that transform IR values correspond to
216 /// disjoint sets of payload IR ops throughout the transformation.
217 ///
218 /// A reference to this class is passed as an argument to "apply" methods of the
219 /// transform op interface. Thus the "apply" method can call
220 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
221 /// associated with its operand and subject to transformation. The method is
222 /// expected to populate the `TransformResults` class instance in order to
223 /// update the mapping. The `applyTransform` method takes care of propagating
224 /// the state of `TransformResults` into the instance of this class.
225 ///
226 /// When applying transform IR operations with regions, the client is expected
227 /// to create a RegionScope RAII object to create a new "stack frame" for
228 /// values defined inside the region. The mappings from and to these values will
229 /// be automatically dropped when the object goes out of scope, typically at the
230 /// end of the "apply" function of the parent operation. If a region contains
231 /// blocks with arguments, the client can map those arguments to payload IR ops
232 /// using "mapBlockArguments".
233 class TransformState {
234 /// Mapping between a Value in the transform IR and the corresponding set of
235 /// operations in the payload IR.
236 using TransformOpMapping = DenseMap<Value, SmallVector<Operation *>>;
237
238 /// Mapping between a payload IR operation and the transform IR value it is
239 /// currently associated with.
240 using TransformOpReverseMapping = DenseMap<Operation *, Value>;
241
242 /// Bidirectional mappings between transform IR values and payload IR
243 /// operations.
244 struct Mappings {
245 TransformOpMapping direct;
246 TransformOpReverseMapping reverse;
247 };
248
249 public:
250 /// Creates a state for transform ops living in the given region. The parent
251 /// operation of the region. The second argument points to the root operation
252 /// in the payload IR beind transformed, which may or may not contain the
253 /// region with transform ops. Additional options can be provided through the
254 /// trailing configuration object.
255 TransformState(Region ®ion, Operation *root,
256 const TransformOptions &options = TransformOptions());
257
258 /// Returns the op at which the transformation state is rooted. This is
259 /// typically helpful for transformations that apply globally.
260 Operation *getTopLevel() const;
261
262 /// Returns the list of ops that the given transform IR value corresponds to.
263 /// This is helpful for transformations that apply to a particular handle.
264 ArrayRef<Operation *> getPayloadOps(Value value) const;
265
266 /// Returns the Transform IR handle for the given Payload IR op if it exists
267 /// in the state, null otherwise.
268 Value getHandleForPayloadOp(Operation *op) const;
269
270 /// Applies the transformation specified by the given transform op and updates
271 /// the state accordingly.
272 DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
273
274 /// Records the mapping between a block argument in the transform IR and a
275 /// list of operations in the payload IR. The arguments must be defined in
276 /// blocks of the currently processed transform IR region, typically after a
277 /// region scope is defined.
mapBlockArguments(BlockArgument argument,ArrayRef<Operation * > operations)278 LogicalResult mapBlockArguments(BlockArgument argument,
279 ArrayRef<Operation *> operations) {
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281 assert(argument.getParentRegion() == regionStack.back() &&
282 "mapping block arguments from a region other than the active one");
283 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
284 return setPayloadOps(argument, operations);
285 }
286
287 // Forward declarations to support limited visibility.
288 class RegionScope;
289
290 /// Creates a new region scope for the given region. The region is expected to
291 /// be nested in the currently processed region.
292 // Implementation note: this method is inline but implemented outside of the
293 // class body to comply with visibility and full-declaration requirements.
294 inline RegionScope make_region_scope(Region ®ion);
295
296 /// A RAII object maintaining a "stack frame" for a transform IR region. When
297 /// applying a transform IR operation that contains a region, the caller is
298 /// expected to create a RegionScope before applying the ops contained in the
299 /// region. This ensures that the mappings between values defined in the
300 /// transform IR region and payload IR operations are cleared when the region
301 /// processing ends; such values cannot be accessed outside the region.
302 class RegionScope {
303 public:
304 /// Forgets the mapping from or to values defined in the associated
305 /// transform IR region.
~RegionScope()306 ~RegionScope() {
307 state.mappings.erase(region);
308 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
309 state.regionStack.pop_back();
310 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
311 }
312
313 private:
314 /// Creates a new scope for mappings between values defined in the given
315 /// transform IR region and payload IR operations.
RegionScope(TransformState & state,Region & region)316 RegionScope(TransformState &state, Region ®ion)
317 : state(state), region(®ion) {
318 auto res = state.mappings.try_emplace(this->region);
319 assert(res.second && "the region scope is already present");
320 (void)res;
321 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
322 assert(state.regionStack.back()->isProperAncestor(®ion) &&
323 "scope started at a non-nested region");
324 state.regionStack.push_back(®ion);
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
326 }
327
328 /// Back-reference to the transform state.
329 TransformState &state;
330
331 /// The region this scope is associated with.
332 Region *region;
333
334 friend RegionScope TransformState::make_region_scope(Region &);
335 };
336 friend class RegionScope;
337
338 /// Base class for TransformState extensions that allow TransformState to
339 /// contain user-specified information in the state object. Clients are
340 /// expected to derive this class, add the desired fields, and make the
341 /// derived class compatible with the MLIR TypeID mechanism:
342 ///
343 /// ```mlir
344 /// class MyExtension final : public TransformState::Extension {
345 /// public:
346 /// MyExtension(TranfsormState &state, int myData)
347 /// : Extension(state) {...}
348 /// private:
349 /// int mySupplementaryData;
350 /// };
351 /// ```
352 ///
353 /// Instances of this and derived classes are not expected to be created by
354 /// the user, instead they are directly constructed within a TransformState. A
355 /// TransformState can only contain one extension with the given TypeID.
356 /// Extensions can be obtained from a TransformState instance, and can be
357 /// removed when they are no longer required.
358 ///
359 /// ```mlir
360 /// transformState.addExtension<MyExtension>(/*myData=*/42);
361 /// MyExtension *ext = transformState.getExtension<MyExtension>();
362 /// ext->doSomething();
363 /// ```
364 class Extension {
365 // Allow TransformState to allocate Extensions.
366 friend class TransformState;
367
368 public:
369 /// Base virtual destructor.
370 // Out-of-line definition ensures symbols are emitted in a single object
371 // file.
372 virtual ~Extension();
373
374 protected:
375 /// Constructs an extension of the given TransformState object.
Extension(TransformState & state)376 Extension(TransformState &state) : state(state) {}
377
378 /// Provides read-only access to the parent TransformState object.
getTransformState()379 const TransformState &getTransformState() const { return state; }
380
381 /// Replaces the given payload op with another op. If the replacement op is
382 /// null, removes the association of the payload op with its handle.
383 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
384
385 private:
386 /// Back-reference to the state that is being extended.
387 TransformState &state;
388 };
389
390 /// Adds a new Extension of the type specified as template parameter,
391 /// constructing it with the arguments provided. The extension is owned by the
392 /// TransformState. It is expected that the state does not already have an
393 /// extension of the same type. Extension constructors are expected to take
394 /// a reference to TransformState as first argument, automatically supplied
395 /// by this call.
396 template <typename Ty, typename... Args>
addExtension(Args &&...args)397 Ty &addExtension(Args &&...args) {
398 static_assert(
399 std::is_base_of<Extension, Ty>::value,
400 "only an class derived from TransformState::Extension is allowed here");
401 auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
402 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
403 assert(result.second && "extension already added");
404 return *static_cast<Ty *>(result.first->second.get());
405 }
406
407 /// Returns the extension of the specified type.
408 template <typename Ty>
getExtension()409 Ty *getExtension() {
410 static_assert(
411 std::is_base_of<Extension, Ty>::value,
412 "only an class derived from TransformState::Extension is allowed here");
413 auto iter = extensions.find(TypeID::get<Ty>());
414 if (iter == extensions.end())
415 return nullptr;
416 return static_cast<Ty *>(iter->second.get());
417 }
418
419 /// Removes the extension of the specified type.
420 template <typename Ty>
removeExtension()421 void removeExtension() {
422 static_assert(
423 std::is_base_of<Extension, Ty>::value,
424 "only an class derived from TransformState::Extension is allowed here");
425 extensions.erase(TypeID::get<Ty>());
426 }
427
428 private:
429 /// Identifier for storing top-level value in the `operations` mapping.
430 static constexpr Value kTopLevelValue = Value();
431
432 /// Returns the mappings frame for the reigon in which the value is defined.
getMapping(Value value)433 const Mappings &getMapping(Value value) const {
434 return const_cast<TransformState *>(this)->getMapping(value);
435 }
getMapping(Value value)436 Mappings &getMapping(Value value) {
437 auto it = mappings.find(value.getParentRegion());
438 assert(it != mappings.end() &&
439 "trying to find a mapping for a value from an unmapped region");
440 return it->second;
441 }
442
443 /// Returns the mappings frame for the region in which the operation resides.
getMapping(Operation * operation)444 const Mappings &getMapping(Operation *operation) const {
445 return const_cast<TransformState *>(this)->getMapping(operation);
446 }
getMapping(Operation * operation)447 Mappings &getMapping(Operation *operation) {
448 auto it = mappings.find(operation->getParentRegion());
449 assert(it != mappings.end() &&
450 "trying to find a mapping for an operation from an unmapped region");
451 return it->second;
452 }
453
454 /// Sets the payload IR ops associated with the given transform IR value.
455 /// Fails if this would result in multiple transform IR values with uses
456 /// corresponding to the same payload IR ops. For example, a hypothetical
457 /// "find function by name" transform op would (indirectly) call this
458 /// function for its result. Having two such calls in a row with for different
459 /// values, e.g. coming from different ops:
460 ///
461 /// %0 = transform.find_func_by_name { name = "myfunc" }
462 /// %1 = transform.find_func_by_name { name = "myfunc" }
463 ///
464 /// would lead to both values pointing to the same operation. The second call
465 /// to setPayloadOps will fail, unless the association with the %0 value is
466 /// removed first by calling update/removePayloadOps.
467 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
468
469 /// Forgets the payload IR ops associated with the given transform IR value.
470 void removePayloadOps(Value value);
471
472 /// Updates the payload IR ops associated with the given transform IR value.
473 /// The callback function is called once per associated operation and is
474 /// expected to return the modified operation or nullptr. In the latter case,
475 /// the corresponding operation is no longer associated with the transform IR
476 /// value. May fail if the operation produced by the update callback is
477 /// already associated with a different Transform IR handle value.
478 LogicalResult
479 updatePayloadOps(Value value,
480 function_ref<Operation *(Operation *)> callback);
481
482 /// Attempts to record the mapping between the given Payload IR operation and
483 /// the given Transform IR handle. Fails and reports an error if the operation
484 /// is already tracked by another handle.
485 static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
486 Value handle);
487
488 /// If the operand is a handle consumed by the operation, i.e. has the "free"
489 /// memory effect associated with it, identifies other handles that are
490 /// pointing to payload IR operations nested in the operations pointed to by
491 /// the consumed handle. Marks all such handles as invalidated so trigger
492 /// errors if they are used.
493 void recordHandleInvalidation(OpOperand &handle);
494
495 /// Checks that the operation does not use invalidated handles as operands.
496 /// Reports errors and returns failure if it does. Otherwise, invalidates the
497 /// handles consumed by the operation as well as any handles pointing to
498 /// payload IR operations nested in the operations associated with the
499 /// consumed handles.
500 LogicalResult
501 checkAndRecordHandleInvalidation(TransformOpInterface transform);
502
503 /// The mappings between transform IR values and payload IR ops, aggregated by
504 /// the region in which the transform IR values are defined.
505 llvm::SmallDenseMap<Region *, Mappings> mappings;
506
507 /// Extensions attached to the TransformState, identified by the TypeID of
508 /// their type. Only one extension of any given type is allowed.
509 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
510
511 /// The top-level operation that contains all payload IR, typically a module.
512 Operation *topLevel;
513
514 /// Additional options controlling the transformation state behavior.
515 TransformOptions options;
516
517 /// The mapping from invalidated handles to the error-reporting functions that
518 /// describe when the handles were invalidated. Calling such a function emits
519 /// a user-visible diagnostic.
520 DenseMap<Value, std::function<void()>> invalidatedHandles;
521
522 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
523 /// A stack of nested regions that are being processed in the transform IR.
524 /// Each region must be an ancestor of the following regions in this list.
525 /// These are also the keys for "mappings".
526 SmallVector<Region *> regionStack;
527 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
528 };
529
/// Local mapping between values defined by a specific op implementing the
/// TransformOpInterface and the payload IR ops they correspond to.
class TransformResults {
  // TransformState constructs instances and reads them back via `get`.
  friend class TransformState;

public:
  /// Indicates that the result of the transform IR op at the given position
  /// corresponds to the given list of payload IR ops. Each result must be set
  /// by the transformation exactly once.
  void set(OpResult value, ArrayRef<Operation *> ops);

private:
  /// Creates an instance of TransformResults that expects mappings for
  /// `numSegments` values.
  explicit TransformResults(unsigned numSegments);

  /// Gets the list of operations associated with the result identified by its
  /// number in the list of operation results.
  ArrayRef<Operation *> get(unsigned resultNumber) const;

  /// Storage for pointers to payload IR ops that are associated with results of
  /// a transform IR op. `segments` contains as many entries as the transform IR
  /// op has results. Each entry is a reference to a contiguous segment in
  /// the `operations` list that contains the pointers to operations. This
  /// allows for operations to be stored contiguously without nested vectors and
  /// for different segments to be set in any order.
  // NOTE(review): `segments` holds views into `operations`; growing
  // `operations` may reallocate it, so the out-of-line `set` presumably
  // reserves or rebuilds the views — confirm in the implementation file.
  SmallVector<ArrayRef<Operation *>, 2> segments;
  SmallVector<Operation *> operations;
};
559
make_region_scope(Region & region)560 TransformState::RegionScope TransformState::make_region_scope(Region ®ion) {
561 return RegionScope(*this, region);
562 }
563
564 namespace detail {
565 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
566 /// to either the list of operations associated with its operand or the root of
567 /// the payload IR, depending on what is available in the context.
568 LogicalResult
569 mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
570 Operation *op, Region ®ion);
571
572 /// Verification hook for PossibleTopLevelTransformOpTrait.
573 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
574 } // namespace detail
575
576 /// This trait is supposed to be attached to Transform dialect operations that
577 /// can be standalone top-level transforms. Such operations typically contain
578 /// other Transform dialect operations that can be executed following some
579 /// control flow logic specific to the current operation. The operations with
580 /// this trait are expected to have at least one single-block region with one
581 /// argument of PDL Operation type. The operations are also expected to be valid
582 /// without operands, in which case they are considered top-level, and with one
583 /// or more arguments, in which case they are considered nested. Top-level
584 /// operations have the block argument of the entry block in the Transform IR
585 /// correspond to the root operation of Payload IR. Nested operations have the
586 /// block argument of the entry block in the Transform IR correspond to a list
587 /// of Payload IR operations mapped to the first operand of the Transform IR
588 /// operation. The operation must implement TransformOpInterface.
589 template <typename OpTy>
590 class PossibleTopLevelTransformOpTrait
591 : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
592 public:
593 /// Verifies that `op` satisfies the invariants of this trait. Not expected to
594 /// be called directly.
verifyTrait(Operation * op)595 static LogicalResult verifyTrait(Operation *op) {
596 return detail::verifyPossibleTopLevelTransformOpTrait(op);
597 }
598
599 /// Returns the single block of the given region.
600 Block *getBodyBlock(unsigned region = 0) {
601 return &this->getOperation()->getRegion(region).front();
602 }
603
604 /// Sets up the mapping between the entry block of the given region of this op
605 /// and the relevant list of Payload IR operations in the given state. The
606 /// state is expected to be already scoped at the region of this operation.
607 /// Returns failure if the mapping failed, e.g., the value is already mapped.
mapBlockArguments(TransformState & state,Region & region)608 LogicalResult mapBlockArguments(TransformState &state, Region ®ion) {
609 assert(region.getParentOp() == this->getOperation() &&
610 "op comes from the wrong region");
611 return detail::mapPossibleTopLevelTransformOpBlockArguments(
612 state, this->getOperation(), region);
613 }
mapBlockArguments(TransformState & state)614 LogicalResult mapBlockArguments(TransformState &state) {
615 assert(
616 this->getOperation()->getNumRegions() == 1 &&
617 "must indicate the region to map if the operation has more than one");
618 return mapBlockArguments(state, this->getOperation()->getRegion(0));
619 }
620 };
621
/// Trait implementing the TransformOpInterface for operations applying a
/// transformation to a single operation handle and producing zero, one or
/// multiple operation handles.
/// The op must implement a method with the following signature:
///   - DiagnosedSilenceableFailure applyToOne(OpTy,
///       SmallVector<Operation*> &results, state)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
/// that the transformation is applied to (and NOT the class of the transform IR
/// op).
/// The `applyToOne` method takes an empty `results` vector that it fills with
/// zero, one or multiple operations depending on the number of results expected
/// by the transform op.
/// The number of results must match the number of results of the transform op.
/// `applyToOne` is allowed to fill the `results` with all null elements to
/// signify that the transformation did not apply to the payload IR operations.
/// Such null elements are filtered out from results before return.
///
/// The transform op having this trait is expected to have a single operand.
template <typename OpTy>
class TransformEachOpTrait
    : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
public:
  /// Calls `applyToOne` for every payload operation associated with the operand
  /// of this transform IR op, and handles the outcome as follows:
  ///   1. If no target payload ops are associated with the operand, fill the
  ///      results vector with the expected number of null elements and return
  ///      success. This is the corner case handling that allows propagating
  ///      the "no-op" case gracefully to improve usability.
  ///   2. If any `applyToOne` returns definiteFailure, the transformation is
  ///      immediately considered definitely failed and we return.
  ///   3. All applications of `applyToOne` are checked to return a number of
  ///      results expected by the transform IR op. If not, this is a definite
  ///      failure and we return early.
  ///   4. If `applyToOne` produces ops, associate them with the result of this
  ///      transform op.
  ///   5. If any `applyToOne` returns silenceableFailure, the transformation is
  ///      considered silenceable.
  ///   6. Otherwise the transformation is considered successful.
  // Defined out of line (not visible in this header section).
  DiagnosedSilenceableFailure apply(TransformResults &transformResults,
                                    TransformState &state);

  /// Checks that the op matches the expectations of this trait.
  static LogicalResult verifyTrait(Operation *op);
};
668
/// Side effect resource corresponding to the mapping between Transform IR
/// values and Payload IR operations. An Allocate effect from this resource
/// means creating a new mapping entry, it is always accompanied by a Write
/// effect. A Read effect from this resource means accessing the mapping. A Free
/// effect on this resource indicates the removal of the mapping entry,
/// typically after a transformation that modifies the Payload IR operations
/// associated with one of the Transform IR operation's operands. It is always
/// accompanied by a Read effect. Read-after-Free and double-Free are not
/// allowed (they would be problematic with "regular" memory effects too) as
/// they indicate an attempt to access Payload IR operations that have been
/// modified, potentially erased, by the previous transformations.
// TODO: consider custom effects if these are not enabling generic passes such
// as CSE/DCE to work.
struct TransformMappingResource
    : public SideEffects::Resource::Base<TransformMappingResource> {
  /// Returns the name identifying this resource.
  StringRef getName() override { return "transform.mapping"; }
};
686
687 /// Side effect resource corresponding to the Payload IR itself. Only Read and
688 /// Write effects are expected on this resource, with Write always accompanied
689 /// by a Read (short of fully replacing the top-level Payload IR operation, one
690 /// cannot modify the Payload IR without reading it first). This is intended
691 /// to disallow reordering of Transform IR operations that mutate the Payload IR
692 /// while still allowing the reordering of those that only access it.
693 struct PayloadIRResource
694 : public SideEffects::Resource::Base<PayloadIRResource> {
getNamePayloadIRResource695 StringRef getName() override { return "transform.payload_ir"; }
696 };
697
/// Populates `effects` with the memory effects describing how a transform op
/// uses each of the given handle values:
///   - consumesHandle:  Read + Free; the mapping entry is removed, so the
///     handle must not be read again afterwards (Read-after-Free is invalid),
///   - producesHandle:  Allocate + Write; a new mapping entry is created,
///   - onlyReadsHandle: Read; the mapping is accessed but left intact.
/// NOTE(review): these are presumably effects on TransformMappingResource; the
/// out-of-line definitions are not visible here — confirm.
void consumesHandle(ValueRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(ValueRange handles,
                    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsHandle(ValueRange handles,
                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects);

/// Checks whether the transform op consumes the given handle.
/// NOTE(review): presumably decided from the memory effects `transform`
/// reports on `handle`; the definition is not visible here — confirm.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);

/// Populates `effects` with the memory effects indicating the access to the
/// Payload IR resource:
///   - modifiesPayload:  the op may mutate the Payload IR (expected to be
///     Read + Write, since Writes to PayloadIRResource must be accompanied by
///     a Read — TODO confirm against the definition),
///   - onlyReadsPayload: the op only inspects the Payload IR (presumably a
///     single Read — TODO confirm).
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
717
718 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
719 /// their operands and produce new results.
720 template <typename OpTy>
721 class FunctionalStyleTransformOpTrait
722 : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
723 public:
724 /// This op "consumes" the operands by reading and freeing then, "produces"
725 /// the results by allocating and writing it and reads/writes the payload IR
726 /// in the process.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)727 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
728 consumesHandle(this->getOperation()->getOperands(), effects);
729 producesHandle(this->getOperation()->getResults(), effects);
730 modifiesPayload(effects);
731 }
732
733 /// Checks that the op matches the expectations of this trait.
verifyTrait(Operation * op)734 static LogicalResult verifyTrait(Operation *op) {
735 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
736 op->emitError()
737 << "FunctionalStyleTransformOpTrait should only be attached to ops "
738 "that implement MemoryEffectOpInterface";
739 }
740 return success();
741 }
742 };
743
744 /// Trait implementing the MemoryEffectOpInterface for single-operand
745 /// single-result operations that use their operand without consuming and
746 /// without modifying the Payload IR to produce a new handle.
747 template <typename OpTy>
748 class NavigationTransformOpTrait
749 : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
750 public:
751 /// This op produces handles to the Payload IR without consuming the original
752 /// handles and without modifying the IR itself.
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)753 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
754 onlyReadsHandle(this->getOperation()->getOperands(), effects);
755 producesHandle(this->getOperation()->getResults(), effects);
756 onlyReadsPayload(effects);
757 }
758
759 /// Checks that the op matches the expectation of this trait.
verifyTrait(Operation * op)760 static LogicalResult verifyTrait(Operation *op) {
761 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
762 "expected single-operand op");
763 static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
764 "expected single-result op");
765 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
766 op->emitError() << "NavigationTransformOpTrait should only be attached "
767 "to ops that implement MemoryEffectOpInterface";
768 }
769 return success();
770 }
771 };
772
773 } // namespace transform
774 } // namespace mlir
775
776 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
777
778 namespace mlir {
779 namespace transform {
780 namespace detail {
781 /// Applies a one-to-one or a one-to-many transform to each of the given
782 /// targets. Puts the results of transforms, if any, in `results` in the same
783 /// order. Fails if any of the application fails. Individual transforms must be
784 /// callable with the following signature:
785 /// - DiagnosedSilenceableFailure(OpTy,
786 /// SmallVector<Operation*> &results, state)
787 /// where OpTy is either
788 /// - Operation *, in which case the transform is always applied;
789 /// - a concrete Op class, in which case a check is performed whether
790 /// `targets` contains operations of the same class and a silenceable failure
791 /// is reported if it does not.
792 template <typename FnTy>
applyTransformToEach(Location loc,int expectedNumResults,ArrayRef<Operation * > targets,SmallVectorImpl<SmallVector<Operation * >> & results,FnTy transform)793 DiagnosedSilenceableFailure applyTransformToEach(
794 Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
795 SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
796 SmallVector<Diagnostic> silenceableStack;
797 using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
798 static_assert(std::is_convertible<OpTy, Operation *>::value,
799 "expected transform function to take an operation");
800 for (Operation *target : targets) {
801 // Emplace back a placeholder for the returned new ops.
802 // This is filled with `expectedNumResults` if the op fails to apply.
803 results.push_back(SmallVector<Operation *>());
804
805 auto specificOp = dyn_cast<OpTy>(target);
806 if (!specificOp) {
807 Diagnostic diag(loc, DiagnosticSeverity::Error);
808 diag << "transform applied to the wrong op kind";
809 diag.attachNote(target->getLoc()) << "when applied to this op";
810 // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
811 // TODO: encode this implicit `expectedNumResults` nullptr ==
812 // silenceableFailure with a proper trait.
813 results.back().assign(expectedNumResults, nullptr);
814 silenceableStack.push_back(std::move(diag));
815 continue;
816 }
817
818 DiagnosedSilenceableFailure result = transform(specificOp, results.back());
819 if (result.isDefiniteFailure())
820 return result;
821 if (result.isSilenceableFailure())
822 for (auto &&diag : result.takeDiagnostics())
823 silenceableStack.push_back(std::move(diag));
824 }
825 if (!silenceableStack.empty()) {
826 return DiagnosedSilenceableFailure::silenceableFailure(
827 std::move(silenceableStack));
828 }
829 return DiagnosedSilenceableFailure::success();
830 }
831
832 /// Helper function: transpose MxN into NxM; assumes that the input is a valid.
833 static inline SmallVector<SmallVector<Operation *, 1>>
transposeResults(const SmallVector<SmallVector<Operation * >,1> & m)834 transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
835 SmallVector<SmallVector<Operation *, 1>> res;
836 if (m.empty())
837 return res;
838 int64_t rows = m.size(), cols = m[0].size();
839 for (int64_t j = 0; j < cols; ++j)
840 res.push_back(SmallVector<Operation *, 1>(rows, nullptr));
841 for (int64_t i = 0; i < rows; ++i) {
842 assert(static_cast<int64_t>(m[i].size()) == cols);
843 for (int64_t j = 0; j < cols; ++j) {
844 res[j][i] = m[i][j];
845 }
846 }
847 return res;
848 }
849 } // namespace detail
850 } // namespace transform
851 } // namespace mlir
852
853 template <typename OpTy>
854 mlir::DiagnosedSilenceableFailure
apply(TransformResults & transformResults,TransformState & state)855 mlir::transform::TransformEachOpTrait<OpTy>::apply(
856 TransformResults &transformResults, TransformState &state) {
857 using TransformOpType = typename llvm::function_traits<
858 decltype(&OpTy::applyToOne)>::template arg_t<0>;
859 ArrayRef<Operation *> targets =
860 state.getPayloadOps(this->getOperation()->getOperand(0));
861
862 // Step 1. Handle the corner case where no target is specified.
863 // This is typically the case when the matcher fails to apply and we need to
864 // propagate gracefully.
865 // In this case, we fill all results with an empty vector.
866 if (targets.empty()) {
867 SmallVector<Operation *> empty;
868 for (auto r : this->getOperation()->getResults())
869 transformResults.set(r.template cast<OpResult>(), empty);
870 return DiagnosedSilenceableFailure::success();
871 }
872
873 // Step 2. Call applyToOne on each target and record newly produced ops in its
874 // corresponding results entry.
875 int expectedNumResults = this->getOperation()->getNumResults();
876 SmallVector<SmallVector<Operation *>, 1> results;
877 DiagnosedSilenceableFailure result = detail::applyTransformToEach(
878 this->getOperation()->getLoc(), expectedNumResults, targets, results,
879 [&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
880 auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
881 partialResult, state);
882 if (res.isDefiniteFailure())
883 return res;
884
885 // TODO: encode this implicit must always produce `expectedNumResults`
886 // and nullptr is fine with a proper trait.
887 if (static_cast<int>(partialResult.size()) != expectedNumResults) {
888 auto loc = this->getOperation()->getLoc();
889 auto diag = mlir::emitError(loc, "applications of ")
890 << OpTy::getOperationName() << " expected to produce "
891 << expectedNumResults << " results (actually produced "
892 << partialResult.size() << ").";
893 diag.attachNote(loc)
894 << "If you need variadic results, consider a generic `apply` "
895 << "instead of the specialized `applyToOne`.";
896 diag.attachNote(loc)
897 << "Producing " << expectedNumResults << " null results is "
898 << "allowed if the use case warrants it.";
899 diag.attachNote(specificOp->getLoc()) << "when applied to this op";
900 return DiagnosedSilenceableFailure::definiteFailure();
901 }
902 // Check that all is null or none is null
903 // TODO: relax this behavior and encode with a proper trait.
904 if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
905 llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
906 auto loc = this->getOperation()->getLoc();
907 auto diag = mlir::emitError(loc, "unexpected application of ")
908 << OpTy::getOperationName()
909 << " produces both null and non null results.";
910 diag.attachNote(specificOp->getLoc()) << "when applied to this op";
911 return DiagnosedSilenceableFailure::definiteFailure();
912 }
913 return res;
914 });
915
916 // Step 3. Propagate the definite failure if any and bail out.
917 if (result.isDefiniteFailure())
918 return result;
919
920 // Step 4. If there are no results, return early.
921 if (OpTy::template hasTrait<OpTrait::ZeroResults>())
922 return result;
923
924 // Step 5. Perform transposition of M applications producing N results each
925 // into N results for each of the M applications.
926 SmallVector<SmallVector<Operation *, 1>> transposedResults =
927 detail::transposeResults(results);
928
929 // Step 6. Single result applies to M ops produces one single M-result.
930 if (OpTy::template hasTrait<OpTrait::OneResult>()) {
931 assert(transposedResults.size() == 1 && "Expected single result");
932 transformResults.set(
933 this->getOperation()->getResult(0).template cast<OpResult>(),
934 transposedResults[0]);
935 // ApplyToOne may have returned silenceableFailure, propagate it.
936 return result;
937 }
938
939 // Step 7. Filter out empty results and set the transformResults.
940 for (const auto &it :
941 llvm::zip(this->getOperation()->getResults(), transposedResults)) {
942 SmallVector<Operation *, 1> filtered;
943 llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
944 [](Operation *op) { return op; });
945 transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
946 }
947
948 // Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
949 return result;
950 }
951
952 template <typename OpTy>
953 mlir::LogicalResult
verifyTrait(Operation * op)954 mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(Operation *op) {
955 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
956 "expected single-operand op");
957 if (!op->getName().getInterface<TransformOpInterface>()) {
958 return op->emitError() << "TransformEachOpTrait should only be attached to "
959 "ops that implement TransformOpInterface";
960 }
961
962 return success();
963 }
964
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
966