1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/Support/Debug.h"
14
15 #define DEBUG_TYPE "transform-dialect"
16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17
18 using namespace mlir;
19
20 //===----------------------------------------------------------------------===//
21 // TransformState
22 //===----------------------------------------------------------------------===//
23
24 constexpr const Value transform::TransformState::kTopLevelValue;
25
TransformState(Region & region,Operation * root,const TransformOptions & options)26 transform::TransformState::TransformState(Region ®ion, Operation *root,
27 const TransformOptions &options)
28 : topLevel(root), options(options) {
29 auto result = mappings.try_emplace(®ion);
30 assert(result.second && "the region scope is already present");
31 (void)result;
32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
33 regionStack.push_back(®ion);
34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35 }
36
getTopLevel() const37 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
38
39 ArrayRef<Operation *>
getPayloadOps(Value value) const40 transform::TransformState::getPayloadOps(Value value) const {
41 const TransformOpMapping &operationMapping = getMapping(value).direct;
42 auto iter = operationMapping.find(value);
43 assert(iter != operationMapping.end() && "unknown handle");
44 return iter->getSecond();
45 }
46
getHandleForPayloadOp(Operation * op) const47 Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
48 for (const Mappings &mapping : llvm::make_second_range(mappings)) {
49 if (Value handle = mapping.reverse.lookup(op))
50 return handle;
51 }
52 return Value();
53 }
54
tryEmplaceReverseMapping(Mappings & map,Operation * operation,Value handle)55 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
56 Mappings &map, Operation *operation, Value handle) {
57 auto insertionResult = map.reverse.insert({operation, handle});
58 if (!insertionResult.second && insertionResult.first->second != handle) {
59 InFlightDiagnostic diag = operation->emitError()
60 << "operation tracked by two handles";
61 diag.attachNote(handle.getLoc()) << "handle";
62 diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
63 return diag;
64 }
65 return success();
66 }
67
68 LogicalResult
setPayloadOps(Value value,ArrayRef<Operation * > targets)69 transform::TransformState::setPayloadOps(Value value,
70 ArrayRef<Operation *> targets) {
71 assert(value != kTopLevelValue &&
72 "attempting to reset the transformation root");
73
74 if (value.use_empty())
75 return success();
76
77 // Setting new payload for the value without cleaning it first is a misuse of
78 // the API, assert here.
79 SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
80 Mappings &mappings = getMapping(value);
81 bool inserted =
82 mappings.direct.insert({value, std::move(storedTargets)}).second;
83 assert(inserted && "value is already associated with another list");
84 (void)inserted;
85
86 // Having multiple handles to the same operation is an error in the transform
87 // expressed using the dialect and may be constructed by valid API calls from
88 // valid IR. Emit an error here.
89 for (Operation *op : targets) {
90 if (failed(tryEmplaceReverseMapping(mappings, op, value)))
91 return failure();
92 }
93
94 return success();
95 }
96
removePayloadOps(Value value)97 void transform::TransformState::removePayloadOps(Value value) {
98 Mappings &mappings = getMapping(value);
99 for (Operation *op : mappings.direct[value])
100 mappings.reverse.erase(op);
101 mappings.direct.erase(value);
102 }
103
updatePayloadOps(Value value,function_ref<Operation * (Operation *)> callback)104 LogicalResult transform::TransformState::updatePayloadOps(
105 Value value, function_ref<Operation *(Operation *)> callback) {
106 Mappings &mappings = getMapping(value);
107 auto it = mappings.direct.find(value);
108 assert(it != mappings.direct.end() && "unknown handle");
109 SmallVector<Operation *> &association = it->getSecond();
110 SmallVector<Operation *> updated;
111 updated.reserve(association.size());
112
113 for (Operation *op : association) {
114 mappings.reverse.erase(op);
115 if (Operation *updatedOp = callback(op)) {
116 updated.push_back(updatedOp);
117 if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
118 return failure();
119 }
120 }
121
122 std::swap(association, updated);
123 return success();
124 }
125
recordHandleInvalidation(OpOperand & handle)126 void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
127 ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
128 for (const Mappings &mapping : llvm::make_second_range(mappings)) {
129 for (const auto &kvp : mapping.reverse) {
130 // If the op is associated with invalidated handle, skip the check as it
131 // may be reading invalid IR.
132 Operation *op = kvp.first;
133 Value otherHandle = kvp.second;
134 if (invalidatedHandles.count(otherHandle))
135 continue;
136
137 for (Operation *ancestor : potentialAncestors) {
138 if (!ancestor->isProperAncestor(op))
139 continue;
140
141 // Make sure the error-reporting lambda doesn't capture anything
142 // by-reference because it will go out of scope. Additionally, extract
143 // location from Payload IR ops because the ops themselves may be
144 // deleted before the lambda gets called.
145 Location ancestorLoc = ancestor->getLoc();
146 Location opLoc = op->getLoc();
147 Operation *owner = handle.getOwner();
148 unsigned operandNo = handle.getOperandNumber();
149 invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
150 otherHandle]() {
151 InFlightDiagnostic diag =
152 owner->emitOpError()
153 << "invalidated the handle to payload operations nested in the "
154 "payload operation associated with its operand #"
155 << operandNo;
156 diag.attachNote(ancestorLoc) << "ancestor op";
157 diag.attachNote(opLoc) << "nested op";
158 diag.attachNote(otherHandle.getLoc()) << "other handle";
159 };
160 }
161 }
162 }
163 }
164
checkAndRecordHandleInvalidation(TransformOpInterface transform)165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
166 TransformOpInterface transform) {
167 auto memoryEffectsIface =
168 cast<MemoryEffectOpInterface>(transform.getOperation());
169 SmallVector<MemoryEffects::EffectInstance> effects;
170 memoryEffectsIface.getEffectsOnResource(
171 transform::TransformMappingResource::get(), effects);
172
173 for (OpOperand &target : transform->getOpOperands()) {
174 // If the operand uses an invalidated handle, report it.
175 auto it = invalidatedHandles.find(target.get());
176 if (it != invalidatedHandles.end())
177 return it->getSecond()(), failure();
178
179 // Invalidate handles pointing to the operations nested in the operation
180 // associated with the handle consumed by this operation.
181 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
182 return isa<MemoryEffects::Free>(effect.getEffect()) &&
183 effect.getValue() == target.get();
184 };
185 if (llvm::any_of(effects, consumesTarget))
186 recordHandleInvalidation(target);
187 }
188 return success();
189 }
190
191 DiagnosedSilenceableFailure
applyTransform(TransformOpInterface transform)192 transform::TransformState::applyTransform(TransformOpInterface transform) {
193 LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
194 if (options.getExpensiveChecksEnabled()) {
195 if (failed(checkAndRecordHandleInvalidation(transform)))
196 return DiagnosedSilenceableFailure::definiteFailure();
197
198 for (OpOperand &operand : transform->getOpOperands()) {
199 if (!isHandleConsumed(operand.get(), transform))
200 continue;
201
202 DenseSet<Operation *> seen;
203 for (Operation *op : getPayloadOps(operand.get())) {
204 if (!seen.insert(op).second) {
205 DiagnosedSilenceableFailure diag =
206 transform.emitSilenceableError()
207 << "a handle passed as operand #" << operand.getOperandNumber()
208 << " and consumed by this operation points to a payload "
209 "operation more than once";
210 diag.attachNote(op->getLoc()) << "repeated target op";
211 return diag;
212 }
213 }
214 }
215 }
216
217 transform::TransformResults results(transform->getNumResults());
218 DiagnosedSilenceableFailure result(transform.apply(results, *this));
219 if (!result.succeeded())
220 return result;
221
222 // Remove the mapping for the operand if it is consumed by the operation. This
223 // allows us to catch use-after-free with assertions later on.
224 auto memEffectInterface =
225 cast<MemoryEffectOpInterface>(transform.getOperation());
226 SmallVector<MemoryEffects::EffectInstance, 2> effects;
227 for (OpOperand &target : transform->getOpOperands()) {
228 effects.clear();
229 memEffectInterface.getEffectsOnValue(target.get(), effects);
230 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
231 return isa<transform::TransformMappingResource>(
232 effect.getResource()) &&
233 isa<MemoryEffects::Free>(effect.getEffect());
234 })) {
235 removePayloadOps(target.get());
236 }
237 }
238
239 for (OpResult result : transform->getResults()) {
240 assert(result.getDefiningOp() == transform.getOperation() &&
241 "payload IR association for a value other than the result of the "
242 "current transform op");
243 if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
244 return DiagnosedSilenceableFailure::definiteFailure();
245 }
246
247 return DiagnosedSilenceableFailure::success();
248 }
249
250 //===----------------------------------------------------------------------===//
251 // TransformState::Extension
252 //===----------------------------------------------------------------------===//
253
254 transform::TransformState::Extension::~Extension() = default;
255
256 LogicalResult
replacePayloadOp(Operation * op,Operation * replacement)257 transform::TransformState::Extension::replacePayloadOp(Operation *op,
258 Operation *replacement) {
259 return state.updatePayloadOps(state.getHandleForPayloadOp(op),
260 [&](Operation *current) {
261 return current == op ? replacement : current;
262 });
263 }
264
265 //===----------------------------------------------------------------------===//
266 // TransformResults
267 //===----------------------------------------------------------------------===//
268
TransformResults(unsigned numSegments)269 transform::TransformResults::TransformResults(unsigned numSegments) {
270 segments.resize(numSegments,
271 ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
272 }
273
set(OpResult value,ArrayRef<Operation * > ops)274 void transform::TransformResults::set(OpResult value,
275 ArrayRef<Operation *> ops) {
276 unsigned position = value.getResultNumber();
277 assert(position < segments.size() &&
278 "setting results for a non-existent handle");
279 assert(segments[position].data() == nullptr && "results already set");
280 unsigned start = operations.size();
281 llvm::append_range(operations, ops);
282 segments[position] = makeArrayRef(operations).drop_front(start);
283 }
284
285 ArrayRef<Operation *>
get(unsigned resultNumber) const286 transform::TransformResults::get(unsigned resultNumber) const {
287 assert(resultNumber < segments.size() &&
288 "querying results for a non-existent handle");
289 assert(segments[resultNumber].data() != nullptr && "querying unset results");
290 return segments[resultNumber];
291 }
292
293 //===----------------------------------------------------------------------===//
294 // Utilities for PossibleTopLevelTransformOpTrait.
295 //===----------------------------------------------------------------------===//
296
mapPossibleTopLevelTransformOpBlockArguments(TransformState & state,Operation * op,Region & region)297 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
298 TransformState &state, Operation *op, Region ®ion) {
299 SmallVector<Operation *> targets;
300 if (op->getNumOperands() != 0)
301 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
302 else
303 targets.push_back(state.getTopLevel());
304
305 return state.mapBlockArguments(region.front().getArgument(0), targets);
306 }
307
308 LogicalResult
verifyPossibleTopLevelTransformOpTrait(Operation * op)309 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
310 // Attaching this trait without the interface is a misuse of the API, but it
311 // cannot be caught via a static_assert because interface registration is
312 // dynamic.
313 assert(isa<TransformOpInterface>(op) &&
314 "should implement TransformOpInterface to have "
315 "PossibleTopLevelTransformOpTrait");
316
317 if (op->getNumRegions() < 1)
318 return op->emitOpError() << "expects at least one region";
319
320 Region *bodyRegion = &op->getRegion(0);
321 if (!llvm::hasNItems(*bodyRegion, 1))
322 return op->emitOpError() << "expects a single-block region";
323
324 Block *body = &bodyRegion->front();
325 if (body->getNumArguments() != 1 ||
326 !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
327 return op->emitOpError()
328 << "expects the entry block to have one argument of type "
329 << pdl::OperationType::get(op->getContext());
330 }
331
332 if (auto *parent =
333 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
334 if (op->getNumOperands() == 0) {
335 InFlightDiagnostic diag =
336 op->emitOpError()
337 << "expects the root operation to be provided for a nested op";
338 diag.attachNote(parent->getLoc())
339 << "nested in another possible top-level op";
340 return diag;
341 }
342 }
343
344 return success();
345 }
346
347 //===----------------------------------------------------------------------===//
348 // Memory effects.
349 //===----------------------------------------------------------------------===//
350
consumesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)351 void transform::consumesHandle(
352 ValueRange handles,
353 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
354 for (Value handle : handles) {
355 effects.emplace_back(MemoryEffects::Read::get(), handle,
356 TransformMappingResource::get());
357 effects.emplace_back(MemoryEffects::Free::get(), handle,
358 TransformMappingResource::get());
359 }
360 }
361
362 /// Returns `true` if the given list of effects instances contains an instance
363 /// with the effect type specified as template parameter.
364 template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects)365 static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
366 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367 return isa<EffectTy>(effect.getEffect()) &&
368 isa<ResourceTy>(effect.getResource());
369 });
370 }
371
isHandleConsumed(Value handle,transform::TransformOpInterface transform)372 bool transform::isHandleConsumed(Value handle,
373 transform::TransformOpInterface transform) {
374 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
375 SmallVector<MemoryEffects::EffectInstance> effects;
376 iface.getEffectsOnValue(handle, effects);
377 return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378 hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
379 }
380
producesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)381 void transform::producesHandle(
382 ValueRange handles,
383 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
384 for (Value handle : handles) {
385 effects.emplace_back(MemoryEffects::Allocate::get(), handle,
386 TransformMappingResource::get());
387 effects.emplace_back(MemoryEffects::Write::get(), handle,
388 TransformMappingResource::get());
389 }
390 }
391
onlyReadsHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)392 void transform::onlyReadsHandle(
393 ValueRange handles,
394 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
395 for (Value handle : handles) {
396 effects.emplace_back(MemoryEffects::Read::get(), handle,
397 TransformMappingResource::get());
398 }
399 }
400
modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)401 void transform::modifiesPayload(
402 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
403 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
404 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
405 }
406
onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)407 void transform::onlyReadsPayload(
408 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
409 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
410 }
411
412 //===----------------------------------------------------------------------===//
413 // Generated interface implementation.
414 //===----------------------------------------------------------------------===//
415
416 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
417