1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TransformState
20 //===----------------------------------------------------------------------===//
21 
22 constexpr const Value transform::TransformState::kTopLevelValue;
23 
24 transform::TransformState::TransformState(Region &region, Operation *root)
25     : topLevel(root) {
26   auto result = mappings.try_emplace(&region);
27   assert(result.second && "the region scope is already present");
28   (void)result;
29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
30   regionStack.push_back(&region);
31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
32 }
33 
34 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
35 
36 ArrayRef<Operation *>
37 transform::TransformState::getPayloadOps(Value value) const {
38   const TransformOpMapping &operationMapping = getMapping(value).direct;
39   auto iter = operationMapping.find(value);
40   assert(iter != operationMapping.end() && "unknown handle");
41   return iter->getSecond();
42 }
43 
44 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
45   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
46     if (Value handle = mapping.reverse.lookup(op))
47       return handle;
48   }
49   return Value();
50 }
51 
52 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
53     Mappings &map, Operation *operation, Value handle) {
54   auto insertionResult = map.reverse.insert({operation, handle});
55   if (!insertionResult.second) {
56     InFlightDiagnostic diag = operation->emitError()
57                               << "operation tracked by two handles";
58     diag.attachNote(handle.getLoc()) << "handle";
59     diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
60     return diag;
61   }
62   return success();
63 }
64 
65 LogicalResult
66 transform::TransformState::setPayloadOps(Value value,
67                                          ArrayRef<Operation *> targets) {
68   assert(value != kTopLevelValue &&
69          "attempting to reset the transformation root");
70 
71   if (value.use_empty())
72     return success();
73 
74   // Setting new payload for the value without cleaning it first is a misuse of
75   // the API, assert here.
76   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
77   Mappings &mappings = getMapping(value);
78   bool inserted =
79       mappings.direct.insert({value, std::move(storedTargets)}).second;
80   assert(inserted && "value is already associated with another list");
81   (void)inserted;
82 
83   // Having multiple handles to the same operation is an error in the transform
84   // expressed using the dialect and may be constructed by valid API calls from
85   // valid IR. Emit an error here.
86   for (Operation *op : targets) {
87     if (failed(tryEmplaceReverseMapping(mappings, op, value)))
88       return failure();
89   }
90 
91   return success();
92 }
93 
94 void transform::TransformState::removePayloadOps(Value value) {
95   Mappings &mappings = getMapping(value);
96   for (Operation *op : mappings.direct[value])
97     mappings.reverse.erase(op);
98   mappings.direct.erase(value);
99 }
100 
101 LogicalResult transform::TransformState::updatePayloadOps(
102     Value value, function_ref<Operation *(Operation *)> callback) {
103   Mappings &mappings = getMapping(value);
104   auto it = mappings.direct.find(value);
105   assert(it != mappings.direct.end() && "unknown handle");
106   SmallVector<Operation *> &association = it->getSecond();
107   SmallVector<Operation *> updated;
108   updated.reserve(association.size());
109 
110   for (Operation *op : association) {
111     mappings.reverse.erase(op);
112     if (Operation *updatedOp = callback(op)) {
113       updated.push_back(updatedOp);
114       if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
115         return failure();
116     }
117   }
118 
119   std::swap(association, updated);
120   return success();
121 }
122 
123 LogicalResult
124 transform::TransformState::applyTransform(TransformOpInterface transform) {
125   transform::TransformResults results(transform->getNumResults());
126   if (failed(transform.apply(results, *this)))
127     return failure();
128 
129   // Remove the mapping for the operand if it is consumed by the operation. This
130   // allows us to catch use-after-free with assertions later on.
131   auto memEffectInterface =
132       cast<MemoryEffectOpInterface>(transform.getOperation());
133   SmallVector<MemoryEffects::EffectInstance, 2> effects;
134   for (Value target : transform->getOperands()) {
135     effects.clear();
136     memEffectInterface.getEffectsOnValue(target, effects);
137     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
138           return isa<transform::TransformMappingResource>(
139                      effect.getResource()) &&
140                  isa<MemoryEffects::Free>(effect.getEffect());
141         })) {
142       removePayloadOps(target);
143     }
144   }
145 
146   for (auto &en : llvm::enumerate(transform->getResults())) {
147     assert(en.value().getDefiningOp() == transform.getOperation() &&
148            "payload IR association for a value other than the result of the "
149            "current transform op");
150     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
151       return failure();
152   }
153 
154   return success();
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // TransformState::Extension
159 //===----------------------------------------------------------------------===//
160 
// Out-of-line defaulted destructor.
// NOTE(review): presumably defined here rather than in the header to anchor
// the vtable of this (assumed polymorphic) class in this translation unit —
// confirm against the header.
transform::TransformState::Extension::~Extension() = default;
162 
163 LogicalResult
164 transform::TransformState::Extension::replacePayloadOp(Operation *op,
165                                                        Operation *replacement) {
166   return state.updatePayloadOps(state.getHandleForPayloadOp(op),
167                                 [&](Operation *current) {
168                                   return current == op ? replacement : current;
169                                 });
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // TransformResults
174 //===----------------------------------------------------------------------===//
175 
176 transform::TransformResults::TransformResults(unsigned numSegments) {
177   segments.resize(numSegments,
178                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
179 }
180 
181 void transform::TransformResults::set(OpResult value,
182                                       ArrayRef<Operation *> ops) {
183   unsigned position = value.getResultNumber();
184   assert(position < segments.size() &&
185          "setting results for a non-existent handle");
186   assert(segments[position].data() == nullptr && "results already set");
187   unsigned start = operations.size();
188   llvm::append_range(operations, ops);
189   segments[position] = makeArrayRef(operations).drop_front(start);
190 }
191 
192 ArrayRef<Operation *>
193 transform::TransformResults::get(unsigned resultNumber) const {
194   assert(resultNumber < segments.size() &&
195          "querying results for a non-existent handle");
196   assert(segments[resultNumber].data() != nullptr && "querying unset results");
197   return segments[resultNumber];
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // Utilities for PossibleTopLevelTransformOpTrait.
202 //===----------------------------------------------------------------------===//
203 
204 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
205     TransformState &state, Operation *op) {
206   SmallVector<Operation *> targets;
207   if (op->getNumOperands() != 0)
208     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
209   else
210     targets.push_back(state.getTopLevel());
211 
212   return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
213                                  targets);
214 }
215 
216 LogicalResult
217 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
218   // Attaching this trait without the interface is a misuse of the API, but it
219   // cannot be caught via a static_assert because interface registration is
220   // dynamic.
221   assert(isa<TransformOpInterface>(op) &&
222          "should implement TransformOpInterface to have "
223          "PossibleTopLevelTransformOpTrait");
224 
225   if (op->getNumRegions() != 1)
226     return op->emitOpError() << "expects one region";
227 
228   Region *bodyRegion = &op->getRegion(0);
229   if (!llvm::hasNItems(*bodyRegion, 1))
230     return op->emitOpError() << "expects a single-block region";
231 
232   Block *body = &bodyRegion->front();
233   if (body->getNumArguments() != 1 ||
234       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
235     return op->emitOpError()
236            << "expects the entry block to have one argument of type "
237            << pdl::OperationType::get(op->getContext());
238   }
239 
240   if (auto *parent =
241           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
242     if (op->getNumOperands() == 0) {
243       InFlightDiagnostic diag =
244           op->emitOpError()
245           << "expects the root operation to be provided for a nested op";
246       diag.attachNote(parent->getLoc())
247           << "nested in another possible top-level op";
248       return diag;
249     }
250   }
251 
252   return success();
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Generated interface implementation.
257 //===----------------------------------------------------------------------===//
258 
259 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
260