1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TransformState
20 //===----------------------------------------------------------------------===//
21 
22 constexpr const Value transform::TransformState::kTopLevelValue;
23 
24 transform::TransformState::TransformState(Region &region, Operation *root,
25                                           const TransformOptions &options)
26     : topLevel(root), options(options) {
27   auto result = mappings.try_emplace(&region);
28   assert(result.second && "the region scope is already present");
29   (void)result;
30 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
31   regionStack.push_back(&region);
32 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
33 }
34 
35 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
36 
37 ArrayRef<Operation *>
38 transform::TransformState::getPayloadOps(Value value) const {
39   const TransformOpMapping &operationMapping = getMapping(value).direct;
40   auto iter = operationMapping.find(value);
41   assert(iter != operationMapping.end() && "unknown handle");
42   return iter->getSecond();
43 }
44 
45 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
46   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
47     if (Value handle = mapping.reverse.lookup(op))
48       return handle;
49   }
50   return Value();
51 }
52 
53 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
54     Mappings &map, Operation *operation, Value handle) {
55   auto insertionResult = map.reverse.insert({operation, handle});
56   if (!insertionResult.second) {
57     InFlightDiagnostic diag = operation->emitError()
58                               << "operation tracked by two handles";
59     diag.attachNote(handle.getLoc()) << "handle";
60     diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
61     return diag;
62   }
63   return success();
64 }
65 
66 LogicalResult
67 transform::TransformState::setPayloadOps(Value value,
68                                          ArrayRef<Operation *> targets) {
69   assert(value != kTopLevelValue &&
70          "attempting to reset the transformation root");
71 
72   if (value.use_empty())
73     return success();
74 
75   // Setting new payload for the value without cleaning it first is a misuse of
76   // the API, assert here.
77   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
78   Mappings &mappings = getMapping(value);
79   bool inserted =
80       mappings.direct.insert({value, std::move(storedTargets)}).second;
81   assert(inserted && "value is already associated with another list");
82   (void)inserted;
83 
84   // Having multiple handles to the same operation is an error in the transform
85   // expressed using the dialect and may be constructed by valid API calls from
86   // valid IR. Emit an error here.
87   for (Operation *op : targets) {
88     if (failed(tryEmplaceReverseMapping(mappings, op, value)))
89       return failure();
90   }
91 
92   return success();
93 }
94 
95 void transform::TransformState::removePayloadOps(Value value) {
96   Mappings &mappings = getMapping(value);
97   for (Operation *op : mappings.direct[value])
98     mappings.reverse.erase(op);
99   mappings.direct.erase(value);
100 }
101 
102 LogicalResult transform::TransformState::updatePayloadOps(
103     Value value, function_ref<Operation *(Operation *)> callback) {
104   Mappings &mappings = getMapping(value);
105   auto it = mappings.direct.find(value);
106   assert(it != mappings.direct.end() && "unknown handle");
107   SmallVector<Operation *> &association = it->getSecond();
108   SmallVector<Operation *> updated;
109   updated.reserve(association.size());
110 
111   for (Operation *op : association) {
112     mappings.reverse.erase(op);
113     if (Operation *updatedOp = callback(op)) {
114       updated.push_back(updatedOp);
115       if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
116         return failure();
117     }
118   }
119 
120   std::swap(association, updated);
121   return success();
122 }
123 
124 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
125   ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
126   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
127     for (const auto &kvp : mapping.reverse) {
128       // If the op is associated with invalidated handle, skip the check as it
129       // may be reading invalid IR.
130       Operation *op = kvp.first;
131       Value otherHandle = kvp.second;
132       if (invalidatedHandles.count(otherHandle))
133         continue;
134 
135       for (Operation *ancestor : potentialAncestors) {
136         if (!ancestor->isProperAncestor(op))
137           continue;
138 
139         // Make sure the error-reporting lambda doesn't capture anything
140         // by-reference because it will go out of scope. Additionally, extract
141         // location from Payload IR ops because the ops themselves may be
142         // deleted before the lambda gets called.
143         Location ancestorLoc = ancestor->getLoc();
144         Location opLoc = op->getLoc();
145         Operation *owner = handle.getOwner();
146         unsigned operandNo = handle.getOperandNumber();
147         invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
148                                            otherHandle]() {
149           InFlightDiagnostic diag =
150               owner->emitOpError()
151               << "invalidated the handle to payload operations nested in the "
152                  "payload operation associated with its operand #"
153               << operandNo;
154           diag.attachNote(ancestorLoc) << "ancestor op";
155           diag.attachNote(opLoc) << "nested op";
156           diag.attachNote(otherHandle.getLoc()) << "other handle";
157         };
158       }
159     }
160   }
161 }
162 
163 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
164     TransformOpInterface transform) {
165   auto memoryEffectsIface =
166       cast<MemoryEffectOpInterface>(transform.getOperation());
167   SmallVector<MemoryEffects::EffectInstance> effects;
168   memoryEffectsIface.getEffectsOnResource(
169       transform::TransformMappingResource::get(), effects);
170 
171   for (OpOperand &target : transform->getOpOperands()) {
172     // If the operand uses an invalidated handle, report it.
173     auto it = invalidatedHandles.find(target.get());
174     if (it != invalidatedHandles.end())
175       return it->getSecond()(), failure();
176 
177     // Invalidate handles pointing to the operations nested in the operation
178     // associated with the handle consumed by this operation.
179     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
180       return isa<MemoryEffects::Free>(effect.getEffect()) &&
181              effect.getValue() == target.get();
182     };
183     if (llvm::find_if(effects, consumesTarget) != effects.end())
184       recordHandleInvalidation(target);
185   }
186   return success();
187 }
188 
189 LogicalResult
190 transform::TransformState::applyTransform(TransformOpInterface transform) {
191   if (options.getExpensiveChecksEnabled() &&
192       failed(checkAndRecordHandleInvalidation(transform))) {
193     return failure();
194   }
195 
196   transform::TransformResults results(transform->getNumResults());
197   if (failed(transform.apply(results, *this)))
198     return failure();
199 
200   // Remove the mapping for the operand if it is consumed by the operation. This
201   // allows us to catch use-after-free with assertions later on.
202   auto memEffectInterface =
203       cast<MemoryEffectOpInterface>(transform.getOperation());
204   SmallVector<MemoryEffects::EffectInstance, 2> effects;
205   for (OpOperand &target : transform->getOpOperands()) {
206     effects.clear();
207     memEffectInterface.getEffectsOnValue(target.get(), effects);
208     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
209           return isa<transform::TransformMappingResource>(
210                      effect.getResource()) &&
211                  isa<MemoryEffects::Free>(effect.getEffect());
212         })) {
213       removePayloadOps(target.get());
214     }
215   }
216 
217   for (OpResult result : transform->getResults()) {
218     assert(result.getDefiningOp() == transform.getOperation() &&
219            "payload IR association for a value other than the result of the "
220            "current transform op");
221     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
222       return failure();
223   }
224 
225   return success();
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // TransformState::Extension
230 //===----------------------------------------------------------------------===//
231 
232 transform::TransformState::Extension::~Extension() = default;
233 
234 LogicalResult
235 transform::TransformState::Extension::replacePayloadOp(Operation *op,
236                                                        Operation *replacement) {
237   return state.updatePayloadOps(state.getHandleForPayloadOp(op),
238                                 [&](Operation *current) {
239                                   return current == op ? replacement : current;
240                                 });
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // TransformResults
245 //===----------------------------------------------------------------------===//
246 
247 transform::TransformResults::TransformResults(unsigned numSegments) {
248   segments.resize(numSegments,
249                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
250 }
251 
252 void transform::TransformResults::set(OpResult value,
253                                       ArrayRef<Operation *> ops) {
254   unsigned position = value.getResultNumber();
255   assert(position < segments.size() &&
256          "setting results for a non-existent handle");
257   assert(segments[position].data() == nullptr && "results already set");
258   unsigned start = operations.size();
259   llvm::append_range(operations, ops);
260   segments[position] = makeArrayRef(operations).drop_front(start);
261 }
262 
263 ArrayRef<Operation *>
264 transform::TransformResults::get(unsigned resultNumber) const {
265   assert(resultNumber < segments.size() &&
266          "querying results for a non-existent handle");
267   assert(segments[resultNumber].data() != nullptr && "querying unset results");
268   return segments[resultNumber];
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // Utilities for PossibleTopLevelTransformOpTrait.
273 //===----------------------------------------------------------------------===//
274 
275 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
276     TransformState &state, Operation *op) {
277   SmallVector<Operation *> targets;
278   if (op->getNumOperands() != 0)
279     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
280   else
281     targets.push_back(state.getTopLevel());
282 
283   return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
284                                  targets);
285 }
286 
287 LogicalResult
288 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
289   // Attaching this trait without the interface is a misuse of the API, but it
290   // cannot be caught via a static_assert because interface registration is
291   // dynamic.
292   assert(isa<TransformOpInterface>(op) &&
293          "should implement TransformOpInterface to have "
294          "PossibleTopLevelTransformOpTrait");
295 
296   if (op->getNumRegions() != 1)
297     return op->emitOpError() << "expects one region";
298 
299   Region *bodyRegion = &op->getRegion(0);
300   if (!llvm::hasNItems(*bodyRegion, 1))
301     return op->emitOpError() << "expects a single-block region";
302 
303   Block *body = &bodyRegion->front();
304   if (body->getNumArguments() != 1 ||
305       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
306     return op->emitOpError()
307            << "expects the entry block to have one argument of type "
308            << pdl::OperationType::get(op->getContext());
309   }
310 
311   if (auto *parent =
312           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
313     if (op->getNumOperands() == 0) {
314       InFlightDiagnostic diag =
315           op->emitOpError()
316           << "expects the root operation to be provided for a nested op";
317       diag.attachNote(parent->getLoc())
318           << "nested in another possible top-level op";
319       return diag;
320     }
321   }
322 
323   return success();
324 }
325 
326 //===----------------------------------------------------------------------===//
327 // Generated interface implementation.
328 //===----------------------------------------------------------------------===//
329 
330 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
331