1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/Support/Debug.h"
14 
15 #define DEBUG_TYPE "transform-dialect"
16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TransformState
22 //===----------------------------------------------------------------------===//
23 
// Namespace-scope definition of the class-scope constant `kTopLevelValue`; its
// initializer lives in the class declaration. setPayloadOps() refuses to
// rebind this value, which stands for the transformation root.
constexpr const Value transform::TransformState::kTopLevelValue;
25 
26 transform::TransformState::TransformState(Region &region, Operation *root,
27                                           const TransformOptions &options)
28     : topLevel(root), options(options) {
29   auto result = mappings.try_emplace(&region);
30   assert(result.second && "the region scope is already present");
31   (void)result;
32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
33   regionStack.push_back(&region);
34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35 }
36 
37 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
38 
39 ArrayRef<Operation *>
40 transform::TransformState::getPayloadOps(Value value) const {
41   const TransformOpMapping &operationMapping = getMapping(value).direct;
42   auto iter = operationMapping.find(value);
43   assert(iter != operationMapping.end() && "unknown handle");
44   return iter->getSecond();
45 }
46 
47 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
48   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
49     if (Value handle = mapping.reverse.lookup(op))
50       return handle;
51   }
52   return Value();
53 }
54 
55 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
56     Mappings &map, Operation *operation, Value handle) {
57   auto insertionResult = map.reverse.insert({operation, handle});
58   if (!insertionResult.second) {
59     InFlightDiagnostic diag = operation->emitError()
60                               << "operation tracked by two handles";
61     diag.attachNote(handle.getLoc()) << "handle";
62     diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
63     return diag;
64   }
65   return success();
66 }
67 
68 LogicalResult
69 transform::TransformState::setPayloadOps(Value value,
70                                          ArrayRef<Operation *> targets) {
71   assert(value != kTopLevelValue &&
72          "attempting to reset the transformation root");
73 
74   if (value.use_empty())
75     return success();
76 
77   // Setting new payload for the value without cleaning it first is a misuse of
78   // the API, assert here.
79   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
80   Mappings &mappings = getMapping(value);
81   bool inserted =
82       mappings.direct.insert({value, std::move(storedTargets)}).second;
83   assert(inserted && "value is already associated with another list");
84   (void)inserted;
85 
86   // Having multiple handles to the same operation is an error in the transform
87   // expressed using the dialect and may be constructed by valid API calls from
88   // valid IR. Emit an error here.
89   for (Operation *op : targets) {
90     if (failed(tryEmplaceReverseMapping(mappings, op, value)))
91       return failure();
92   }
93 
94   return success();
95 }
96 
97 void transform::TransformState::removePayloadOps(Value value) {
98   Mappings &mappings = getMapping(value);
99   for (Operation *op : mappings.direct[value])
100     mappings.reverse.erase(op);
101   mappings.direct.erase(value);
102 }
103 
104 LogicalResult transform::TransformState::updatePayloadOps(
105     Value value, function_ref<Operation *(Operation *)> callback) {
106   Mappings &mappings = getMapping(value);
107   auto it = mappings.direct.find(value);
108   assert(it != mappings.direct.end() && "unknown handle");
109   SmallVector<Operation *> &association = it->getSecond();
110   SmallVector<Operation *> updated;
111   updated.reserve(association.size());
112 
113   for (Operation *op : association) {
114     mappings.reverse.erase(op);
115     if (Operation *updatedOp = callback(op)) {
116       updated.push_back(updatedOp);
117       if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
118         return failure();
119     }
120   }
121 
122   std::swap(association, updated);
123   return success();
124 }
125 
126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
127   ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
128   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
129     for (const auto &kvp : mapping.reverse) {
130       // If the op is associated with invalidated handle, skip the check as it
131       // may be reading invalid IR.
132       Operation *op = kvp.first;
133       Value otherHandle = kvp.second;
134       if (invalidatedHandles.count(otherHandle))
135         continue;
136 
137       for (Operation *ancestor : potentialAncestors) {
138         if (!ancestor->isProperAncestor(op))
139           continue;
140 
141         // Make sure the error-reporting lambda doesn't capture anything
142         // by-reference because it will go out of scope. Additionally, extract
143         // location from Payload IR ops because the ops themselves may be
144         // deleted before the lambda gets called.
145         Location ancestorLoc = ancestor->getLoc();
146         Location opLoc = op->getLoc();
147         Operation *owner = handle.getOwner();
148         unsigned operandNo = handle.getOperandNumber();
149         invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
150                                            otherHandle]() {
151           InFlightDiagnostic diag =
152               owner->emitOpError()
153               << "invalidated the handle to payload operations nested in the "
154                  "payload operation associated with its operand #"
155               << operandNo;
156           diag.attachNote(ancestorLoc) << "ancestor op";
157           diag.attachNote(opLoc) << "nested op";
158           diag.attachNote(otherHandle.getLoc()) << "other handle";
159         };
160       }
161     }
162   }
163 }
164 
165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
166     TransformOpInterface transform) {
167   auto memoryEffectsIface =
168       cast<MemoryEffectOpInterface>(transform.getOperation());
169   SmallVector<MemoryEffects::EffectInstance> effects;
170   memoryEffectsIface.getEffectsOnResource(
171       transform::TransformMappingResource::get(), effects);
172 
173   for (OpOperand &target : transform->getOpOperands()) {
174     // If the operand uses an invalidated handle, report it.
175     auto it = invalidatedHandles.find(target.get());
176     if (it != invalidatedHandles.end())
177       return it->getSecond()(), failure();
178 
179     // Invalidate handles pointing to the operations nested in the operation
180     // associated with the handle consumed by this operation.
181     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
182       return isa<MemoryEffects::Free>(effect.getEffect()) &&
183              effect.getValue() == target.get();
184     };
185     if (llvm::find_if(effects, consumesTarget) != effects.end())
186       recordHandleInvalidation(target);
187   }
188   return success();
189 }
190 
191 DiagnosedSilenceableFailure
192 transform::TransformState::applyTransform(TransformOpInterface transform) {
193   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
194   if (options.getExpensiveChecksEnabled() &&
195       failed(checkAndRecordHandleInvalidation(transform))) {
196     return DiagnosedSilenceableFailure::definiteFailure();
197   }
198 
199   transform::TransformResults results(transform->getNumResults());
200   DiagnosedSilenceableFailure result(transform.apply(results, *this));
201   if (!result.succeeded())
202     return result;
203 
204   // Remove the mapping for the operand if it is consumed by the operation. This
205   // allows us to catch use-after-free with assertions later on.
206   auto memEffectInterface =
207       cast<MemoryEffectOpInterface>(transform.getOperation());
208   SmallVector<MemoryEffects::EffectInstance, 2> effects;
209   for (OpOperand &target : transform->getOpOperands()) {
210     effects.clear();
211     memEffectInterface.getEffectsOnValue(target.get(), effects);
212     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
213           return isa<transform::TransformMappingResource>(
214                      effect.getResource()) &&
215                  isa<MemoryEffects::Free>(effect.getEffect());
216         })) {
217       removePayloadOps(target.get());
218     }
219   }
220 
221   for (OpResult result : transform->getResults()) {
222     assert(result.getDefiningOp() == transform.getOperation() &&
223            "payload IR association for a value other than the result of the "
224            "current transform op");
225     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
226       return DiagnosedSilenceableFailure::definiteFailure();
227   }
228 
229   return DiagnosedSilenceableFailure::success();
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // TransformState::Extension
234 //===----------------------------------------------------------------------===//
235 
// Out-of-line defaulted destructor. NOTE(review): presumably declared virtual
// in the header so that this definition anchors the class's vtable in this
// translation unit — confirm against the header.
transform::TransformState::Extension::~Extension() = default;
237 
238 LogicalResult
239 transform::TransformState::Extension::replacePayloadOp(Operation *op,
240                                                        Operation *replacement) {
241   return state.updatePayloadOps(state.getHandleForPayloadOp(op),
242                                 [&](Operation *current) {
243                                   return current == op ? replacement : current;
244                                 });
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // TransformResults
249 //===----------------------------------------------------------------------===//
250 
251 transform::TransformResults::TransformResults(unsigned numSegments) {
252   segments.resize(numSegments,
253                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
254 }
255 
256 void transform::TransformResults::set(OpResult value,
257                                       ArrayRef<Operation *> ops) {
258   unsigned position = value.getResultNumber();
259   assert(position < segments.size() &&
260          "setting results for a non-existent handle");
261   assert(segments[position].data() == nullptr && "results already set");
262   unsigned start = operations.size();
263   llvm::append_range(operations, ops);
264   segments[position] = makeArrayRef(operations).drop_front(start);
265 }
266 
267 ArrayRef<Operation *>
268 transform::TransformResults::get(unsigned resultNumber) const {
269   assert(resultNumber < segments.size() &&
270          "querying results for a non-existent handle");
271   assert(segments[resultNumber].data() != nullptr && "querying unset results");
272   return segments[resultNumber];
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Utilities for PossibleTopLevelTransformOpTrait.
277 //===----------------------------------------------------------------------===//
278 
279 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
280     TransformState &state, Operation *op, Region &region) {
281   SmallVector<Operation *> targets;
282   if (op->getNumOperands() != 0)
283     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
284   else
285     targets.push_back(state.getTopLevel());
286 
287   return state.mapBlockArguments(region.front().getArgument(0), targets);
288 }
289 
290 LogicalResult
291 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
292   // Attaching this trait without the interface is a misuse of the API, but it
293   // cannot be caught via a static_assert because interface registration is
294   // dynamic.
295   assert(isa<TransformOpInterface>(op) &&
296          "should implement TransformOpInterface to have "
297          "PossibleTopLevelTransformOpTrait");
298 
299   if (op->getNumRegions() < 1)
300     return op->emitOpError() << "expects at least one region";
301 
302   Region *bodyRegion = &op->getRegion(0);
303   if (!llvm::hasNItems(*bodyRegion, 1))
304     return op->emitOpError() << "expects a single-block region";
305 
306   Block *body = &bodyRegion->front();
307   if (body->getNumArguments() != 1 ||
308       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
309     return op->emitOpError()
310            << "expects the entry block to have one argument of type "
311            << pdl::OperationType::get(op->getContext());
312   }
313 
314   if (auto *parent =
315           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
316     if (op->getNumOperands() == 0) {
317       InFlightDiagnostic diag =
318           op->emitOpError()
319           << "expects the root operation to be provided for a nested op";
320       diag.attachNote(parent->getLoc())
321           << "nested in another possible top-level op";
322       return diag;
323     }
324   }
325 
326   return success();
327 }
328 
329 //===----------------------------------------------------------------------===//
330 // Generated interface implementation.
331 //===----------------------------------------------------------------------===//
332 
333 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
334