1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/Support/Debug.h"
14 
15 #define DEBUG_TYPE "transform-dialect"
16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // TransformState
22 //===----------------------------------------------------------------------===//
23 
24 constexpr const Value transform::TransformState::kTopLevelValue;
25 
26 transform::TransformState::TransformState(Region &region, Operation *root,
27                                           const TransformOptions &options)
28     : topLevel(root), options(options) {
29   auto result = mappings.try_emplace(&region);
30   assert(result.second && "the region scope is already present");
31   (void)result;
32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
33   regionStack.push_back(&region);
34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35 }
36 
37 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
38 
39 ArrayRef<Operation *>
40 transform::TransformState::getPayloadOps(Value value) const {
41   const TransformOpMapping &operationMapping = getMapping(value).direct;
42   auto iter = operationMapping.find(value);
43   assert(iter != operationMapping.end() && "unknown handle");
44   return iter->getSecond();
45 }
46 
47 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
48   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
49     if (Value handle = mapping.reverse.lookup(op))
50       return handle;
51   }
52   return Value();
53 }
54 
55 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
56     Mappings &map, Operation *operation, Value handle) {
57   auto insertionResult = map.reverse.insert({operation, handle});
58   if (!insertionResult.second && insertionResult.first->second != handle) {
59     InFlightDiagnostic diag = operation->emitError()
60                               << "operation tracked by two handles";
61     diag.attachNote(handle.getLoc()) << "handle";
62     diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
63     return diag;
64   }
65   return success();
66 }
67 
68 LogicalResult
69 transform::TransformState::setPayloadOps(Value value,
70                                          ArrayRef<Operation *> targets) {
71   assert(value != kTopLevelValue &&
72          "attempting to reset the transformation root");
73 
74   if (value.use_empty())
75     return success();
76 
77   // Setting new payload for the value without cleaning it first is a misuse of
78   // the API, assert here.
79   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
80   Mappings &mappings = getMapping(value);
81   bool inserted =
82       mappings.direct.insert({value, std::move(storedTargets)}).second;
83   assert(inserted && "value is already associated with another list");
84   (void)inserted;
85 
86   // Having multiple handles to the same operation is an error in the transform
87   // expressed using the dialect and may be constructed by valid API calls from
88   // valid IR. Emit an error here.
89   for (Operation *op : targets) {
90     if (failed(tryEmplaceReverseMapping(mappings, op, value)))
91       return failure();
92   }
93 
94   return success();
95 }
96 
97 void transform::TransformState::removePayloadOps(Value value) {
98   Mappings &mappings = getMapping(value);
99   for (Operation *op : mappings.direct[value])
100     mappings.reverse.erase(op);
101   mappings.direct.erase(value);
102 }
103 
104 LogicalResult transform::TransformState::updatePayloadOps(
105     Value value, function_ref<Operation *(Operation *)> callback) {
106   Mappings &mappings = getMapping(value);
107   auto it = mappings.direct.find(value);
108   assert(it != mappings.direct.end() && "unknown handle");
109   SmallVector<Operation *> &association = it->getSecond();
110   SmallVector<Operation *> updated;
111   updated.reserve(association.size());
112 
113   for (Operation *op : association) {
114     mappings.reverse.erase(op);
115     if (Operation *updatedOp = callback(op)) {
116       updated.push_back(updatedOp);
117       if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
118         return failure();
119     }
120   }
121 
122   std::swap(association, updated);
123   return success();
124 }
125 
126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
127   ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
128   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
129     for (const auto &kvp : mapping.reverse) {
130       // If the op is associated with invalidated handle, skip the check as it
131       // may be reading invalid IR.
132       Operation *op = kvp.first;
133       Value otherHandle = kvp.second;
134       if (invalidatedHandles.count(otherHandle))
135         continue;
136 
137       for (Operation *ancestor : potentialAncestors) {
138         if (!ancestor->isProperAncestor(op))
139           continue;
140 
141         // Make sure the error-reporting lambda doesn't capture anything
142         // by-reference because it will go out of scope. Additionally, extract
143         // location from Payload IR ops because the ops themselves may be
144         // deleted before the lambda gets called.
145         Location ancestorLoc = ancestor->getLoc();
146         Location opLoc = op->getLoc();
147         Operation *owner = handle.getOwner();
148         unsigned operandNo = handle.getOperandNumber();
149         invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
150                                            otherHandle]() {
151           InFlightDiagnostic diag =
152               owner->emitOpError()
153               << "invalidated the handle to payload operations nested in the "
154                  "payload operation associated with its operand #"
155               << operandNo;
156           diag.attachNote(ancestorLoc) << "ancestor op";
157           diag.attachNote(opLoc) << "nested op";
158           diag.attachNote(otherHandle.getLoc()) << "other handle";
159         };
160       }
161     }
162   }
163 }
164 
165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
166     TransformOpInterface transform) {
167   auto memoryEffectsIface =
168       cast<MemoryEffectOpInterface>(transform.getOperation());
169   SmallVector<MemoryEffects::EffectInstance> effects;
170   memoryEffectsIface.getEffectsOnResource(
171       transform::TransformMappingResource::get(), effects);
172 
173   for (OpOperand &target : transform->getOpOperands()) {
174     // If the operand uses an invalidated handle, report it.
175     auto it = invalidatedHandles.find(target.get());
176     if (it != invalidatedHandles.end())
177       return it->getSecond()(), failure();
178 
179     // Invalidate handles pointing to the operations nested in the operation
180     // associated with the handle consumed by this operation.
181     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
182       return isa<MemoryEffects::Free>(effect.getEffect()) &&
183              effect.getValue() == target.get();
184     };
185     if (llvm::any_of(effects, consumesTarget))
186       recordHandleInvalidation(target);
187   }
188   return success();
189 }
190 
191 DiagnosedSilenceableFailure
192 transform::TransformState::applyTransform(TransformOpInterface transform) {
193   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
194   if (options.getExpensiveChecksEnabled()) {
195     if (failed(checkAndRecordHandleInvalidation(transform)))
196       return DiagnosedSilenceableFailure::definiteFailure();
197 
198     for (OpOperand &operand : transform->getOpOperands()) {
199       if (!isHandleConsumed(operand.get(), transform))
200         continue;
201 
202       DenseSet<Operation *> seen;
203       for (Operation *op : getPayloadOps(operand.get())) {
204         if (!seen.insert(op).second) {
205           DiagnosedSilenceableFailure diag =
206               transform.emitSilenceableError()
207               << "a handle passed as operand #" << operand.getOperandNumber()
208               << " and consumed by this operation points to a payload "
209                  "operation more than once";
210           diag.attachNote(op->getLoc()) << "repeated target op";
211           return diag;
212         }
213       }
214     }
215   }
216 
217   transform::TransformResults results(transform->getNumResults());
218   DiagnosedSilenceableFailure result(transform.apply(results, *this));
219   if (!result.succeeded())
220     return result;
221 
222   // Remove the mapping for the operand if it is consumed by the operation. This
223   // allows us to catch use-after-free with assertions later on.
224   auto memEffectInterface =
225       cast<MemoryEffectOpInterface>(transform.getOperation());
226   SmallVector<MemoryEffects::EffectInstance, 2> effects;
227   for (OpOperand &target : transform->getOpOperands()) {
228     effects.clear();
229     memEffectInterface.getEffectsOnValue(target.get(), effects);
230     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
231           return isa<transform::TransformMappingResource>(
232                      effect.getResource()) &&
233                  isa<MemoryEffects::Free>(effect.getEffect());
234         })) {
235       removePayloadOps(target.get());
236     }
237   }
238 
239   for (OpResult result : transform->getResults()) {
240     assert(result.getDefiningOp() == transform.getOperation() &&
241            "payload IR association for a value other than the result of the "
242            "current transform op");
243     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
244       return DiagnosedSilenceableFailure::definiteFailure();
245   }
246 
247   return DiagnosedSilenceableFailure::success();
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // TransformState::Extension
252 //===----------------------------------------------------------------------===//
253 
// Out-of-line defaulted destructor: its definition is emitted in this
// translation unit.
transform::TransformState::Extension::~Extension() = default;
255 
256 LogicalResult
257 transform::TransformState::Extension::replacePayloadOp(Operation *op,
258                                                        Operation *replacement) {
259   return state.updatePayloadOps(state.getHandleForPayloadOp(op),
260                                 [&](Operation *current) {
261                                   return current == op ? replacement : current;
262                                 });
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // TransformResults
267 //===----------------------------------------------------------------------===//
268 
269 transform::TransformResults::TransformResults(unsigned numSegments) {
270   segments.resize(numSegments,
271                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
272 }
273 
274 void transform::TransformResults::set(OpResult value,
275                                       ArrayRef<Operation *> ops) {
276   unsigned position = value.getResultNumber();
277   assert(position < segments.size() &&
278          "setting results for a non-existent handle");
279   assert(segments[position].data() == nullptr && "results already set");
280   unsigned start = operations.size();
281   llvm::append_range(operations, ops);
282   segments[position] = makeArrayRef(operations).drop_front(start);
283 }
284 
285 ArrayRef<Operation *>
286 transform::TransformResults::get(unsigned resultNumber) const {
287   assert(resultNumber < segments.size() &&
288          "querying results for a non-existent handle");
289   assert(segments[resultNumber].data() != nullptr && "querying unset results");
290   return segments[resultNumber];
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Utilities for PossibleTopLevelTransformOpTrait.
295 //===----------------------------------------------------------------------===//
296 
297 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
298     TransformState &state, Operation *op, Region &region) {
299   SmallVector<Operation *> targets;
300   if (op->getNumOperands() != 0)
301     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
302   else
303     targets.push_back(state.getTopLevel());
304 
305   return state.mapBlockArguments(region.front().getArgument(0), targets);
306 }
307 
308 LogicalResult
309 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
310   // Attaching this trait without the interface is a misuse of the API, but it
311   // cannot be caught via a static_assert because interface registration is
312   // dynamic.
313   assert(isa<TransformOpInterface>(op) &&
314          "should implement TransformOpInterface to have "
315          "PossibleTopLevelTransformOpTrait");
316 
317   if (op->getNumRegions() < 1)
318     return op->emitOpError() << "expects at least one region";
319 
320   Region *bodyRegion = &op->getRegion(0);
321   if (!llvm::hasNItems(*bodyRegion, 1))
322     return op->emitOpError() << "expects a single-block region";
323 
324   Block *body = &bodyRegion->front();
325   if (body->getNumArguments() != 1 ||
326       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
327     return op->emitOpError()
328            << "expects the entry block to have one argument of type "
329            << pdl::OperationType::get(op->getContext());
330   }
331 
332   if (auto *parent =
333           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
334     if (op->getNumOperands() == 0) {
335       InFlightDiagnostic diag =
336           op->emitOpError()
337           << "expects the root operation to be provided for a nested op";
338       diag.attachNote(parent->getLoc())
339           << "nested in another possible top-level op";
340       return diag;
341     }
342   }
343 
344   return success();
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // Memory effects.
349 //===----------------------------------------------------------------------===//
350 
351 void transform::consumesHandle(
352     ValueRange handles,
353     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
354   for (Value handle : handles) {
355     effects.emplace_back(MemoryEffects::Read::get(), handle,
356                          TransformMappingResource::get());
357     effects.emplace_back(MemoryEffects::Free::get(), handle,
358                          TransformMappingResource::get());
359   }
360 }
361 
362 /// Returns `true` if the given list of effects instances contains an instance
363 /// with the effect type specified as template parameter.
364 template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
365 static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
366   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367     return isa<EffectTy>(effect.getEffect()) &&
368            isa<ResourceTy>(effect.getResource());
369   });
370 }
371 
372 bool transform::isHandleConsumed(Value handle,
373                                  transform::TransformOpInterface transform) {
374   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
375   SmallVector<MemoryEffects::EffectInstance> effects;
376   iface.getEffectsOnValue(handle, effects);
377   return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378          hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
379 }
380 
381 void transform::producesHandle(
382     ValueRange handles,
383     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
384   for (Value handle : handles) {
385     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
386                          TransformMappingResource::get());
387     effects.emplace_back(MemoryEffects::Write::get(), handle,
388                          TransformMappingResource::get());
389   }
390 }
391 
392 void transform::onlyReadsHandle(
393     ValueRange handles,
394     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
395   for (Value handle : handles) {
396     effects.emplace_back(MemoryEffects::Read::get(), handle,
397                          TransformMappingResource::get());
398   }
399 }
400 
401 void transform::modifiesPayload(
402     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
403   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
404   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
405 }
406 
407 void transform::onlyReadsPayload(
408     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
409   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // Generated interface implementation.
414 //===----------------------------------------------------------------------===//
415 
416 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
417