1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TransformState
20 //===----------------------------------------------------------------------===//
21 
22 constexpr const Value transform::TransformState::kTopLevelValue;
23 
24 transform::TransformState::TransformState(Region &region, Operation *root)
25     : topLevel(root) {
26   auto result = mappings.try_emplace(&region);
27   assert(result.second && "the region scope is already present");
28   (void)result;
29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
30   regionStack.push_back(&region);
31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
32 }
33 
34 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
35 
36 ArrayRef<Operation *>
37 transform::TransformState::getPayloadOps(Value value) const {
38   const TransformOpMapping &operationMapping = getMapping(value).direct;
39   auto iter = operationMapping.find(value);
40   assert(iter != operationMapping.end() && "unknown handle");
41   return iter->getSecond();
42 }
43 
44 LogicalResult
45 transform::TransformState::setPayloadOps(Value value,
46                                          ArrayRef<Operation *> targets) {
47   assert(value != kTopLevelValue &&
48          "attempting to reset the transformation root");
49 
50   if (value.use_empty())
51     return success();
52 
53   // Setting new payload for the value without cleaning it first is a misuse of
54   // the API, assert here.
55   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
56   Mappings &mappings = getMapping(value);
57   bool inserted =
58       mappings.direct.insert({value, std::move(storedTargets)}).second;
59   assert(inserted && "value is already associated with another list");
60   (void)inserted;
61 
62   // Having multiple handles to the same operation is an error in the transform
63   // expressed using the dialect and may be constructed by valid API calls from
64   // valid IR. Emit an error here.
65   for (Operation *op : targets) {
66     auto insertionResult = mappings.reverse.insert({op, value});
67     if (!insertionResult.second) {
68       InFlightDiagnostic diag = op->emitError()
69                                 << "operation tracked by two handles";
70       diag.attachNote(value.getLoc()) << "handle";
71       diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
72       return diag;
73     }
74   }
75 
76   return success();
77 }
78 
79 void transform::TransformState::removePayloadOps(Value value) {
80   Mappings &mappings = getMapping(value);
81   for (Operation *op : mappings.direct[value])
82     mappings.reverse.erase(op);
83   mappings.direct.erase(value);
84 }
85 
86 void transform::TransformState::updatePayloadOps(
87     Value value, function_ref<Operation *(Operation *)> callback) {
88   auto it = getMapping(value).direct.find(value);
89   assert(it != getMapping(value).direct.end() && "unknown handle");
90   SmallVector<Operation *> &association = it->getSecond();
91   SmallVector<Operation *> updated;
92   updated.reserve(association.size());
93 
94   for (Operation *op : association)
95     if (Operation *updatedOp = callback(op))
96       updated.push_back(updatedOp);
97 
98   std::swap(association, updated);
99 }
100 
101 LogicalResult
102 transform::TransformState::applyTransform(TransformOpInterface transform) {
103   transform::TransformResults results(transform->getNumResults());
104   if (failed(transform.apply(results, *this)))
105     return failure();
106 
107   // Remove the mapping for the operand if it is consumed by the operation. This
108   // allows us to catch use-after-free with assertions later on.
109   auto memEffectInterface =
110       cast<MemoryEffectOpInterface>(transform.getOperation());
111   SmallVector<MemoryEffects::EffectInstance, 2> effects;
112   for (Value target : transform->getOperands()) {
113     effects.clear();
114     memEffectInterface.getEffectsOnValue(target, effects);
115     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
116           return isa<transform::TransformMappingResource>(
117                      effect.getResource()) &&
118                  isa<MemoryEffects::Free>(effect.getEffect());
119         })) {
120       removePayloadOps(target);
121     }
122   }
123 
124   for (auto &en : llvm::enumerate(transform->getResults())) {
125     assert(en.value().getDefiningOp() == transform.getOperation() &&
126            "payload IR association for a value other than the result of the "
127            "current transform op");
128     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
129       return failure();
130   }
131 
132   return success();
133 }
134 
135 transform::TransformState::Extension::~Extension() = default;
136 
137 //===----------------------------------------------------------------------===//
138 // TransformResults
139 //===----------------------------------------------------------------------===//
140 
141 transform::TransformResults::TransformResults(unsigned numSegments) {
142   segments.resize(numSegments,
143                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
144 }
145 
146 void transform::TransformResults::set(OpResult value,
147                                       ArrayRef<Operation *> ops) {
148   unsigned position = value.getResultNumber();
149   assert(position < segments.size() &&
150          "setting results for a non-existent handle");
151   assert(segments[position].data() == nullptr && "results already set");
152   unsigned start = operations.size();
153   llvm::append_range(operations, ops);
154   segments[position] = makeArrayRef(operations).drop_front(start);
155 }
156 
157 ArrayRef<Operation *>
158 transform::TransformResults::get(unsigned resultNumber) const {
159   assert(resultNumber < segments.size() &&
160          "querying results for a non-existent handle");
161   assert(segments[resultNumber].data() != nullptr && "querying unset results");
162   return segments[resultNumber];
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // Utilities for PossibleTopLevelTransformOpTrait.
167 //===----------------------------------------------------------------------===//
168 
169 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
170     TransformState &state, Operation *op) {
171   SmallVector<Operation *> targets;
172   if (op->getNumOperands() != 0)
173     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
174   else
175     targets.push_back(state.getTopLevel());
176 
177   return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
178                                  targets);
179 }
180 
181 LogicalResult
182 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
183   // Attaching this trait without the interface is a misuse of the API, but it
184   // cannot be caught via a static_assert because interface registration is
185   // dynamic.
186   assert(isa<TransformOpInterface>(op) &&
187          "should implement TransformOpInterface to have "
188          "PossibleTopLevelTransformOpTrait");
189 
190   if (op->getNumRegions() != 1)
191     return op->emitOpError() << "expects one region";
192 
193   Region *bodyRegion = &op->getRegion(0);
194   if (!llvm::hasNItems(*bodyRegion, 1))
195     return op->emitOpError() << "expects a single-block region";
196 
197   Block *body = &bodyRegion->front();
198   if (body->getNumArguments() != 1 ||
199       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
200     return op->emitOpError()
201            << "expects the entry block to have one argument of type "
202            << pdl::OperationType::get(op->getContext());
203   }
204 
205   if (auto *parent =
206           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
207     if (op->getNumOperands() == 0) {
208       InFlightDiagnostic diag =
209           op->emitOpError()
210           << "expects the root operation to be provided for a nested op";
211       diag.attachNote(parent->getLoc())
212           << "nested in another possible top-level op";
213       return diag;
214     }
215   }
216 
217   return success();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Generated interface implementation.
222 //===----------------------------------------------------------------------===//
223 
224 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
225