//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8d064c480SAlex Zinenko
9d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1030f22429SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11d064c480SAlex Zinenko #include "mlir/IR/Diagnostics.h"
12d064c480SAlex Zinenko #include "mlir/IR/Operation.h"
13e3890b7fSAlex Zinenko #include "llvm/Support/Debug.h"
14e3890b7fSAlex Zinenko
15e3890b7fSAlex Zinenko #define DEBUG_TYPE "transform-dialect"
16e3890b7fSAlex Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17d064c480SAlex Zinenko
18d064c480SAlex Zinenko using namespace mlir;
19d064c480SAlex Zinenko
20d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
21d064c480SAlex Zinenko // TransformState
22d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
23d064c480SAlex Zinenko
24d064c480SAlex Zinenko constexpr const Value transform::TransformState::kTopLevelValue;
25d064c480SAlex Zinenko
TransformState(Region & region,Operation * root,const TransformOptions & options)266403e1b1SAlex Zinenko transform::TransformState::TransformState(Region ®ion, Operation *root,
276403e1b1SAlex Zinenko const TransformOptions &options)
286403e1b1SAlex Zinenko : topLevel(root), options(options) {
290eb403adSAlex Zinenko auto result = mappings.try_emplace(®ion);
300eb403adSAlex Zinenko assert(result.second && "the region scope is already present");
310eb403adSAlex Zinenko (void)result;
320eb403adSAlex Zinenko #if LLVM_ENABLE_ABI_BREAKING_CHECKS
330eb403adSAlex Zinenko regionStack.push_back(®ion);
340eb403adSAlex Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35d064c480SAlex Zinenko }
36d064c480SAlex Zinenko
getTopLevel() const370eb403adSAlex Zinenko Operation *transform::TransformState::getTopLevel() const { return topLevel; }
38d064c480SAlex Zinenko
39d064c480SAlex Zinenko ArrayRef<Operation *>
getPayloadOps(Value value) const40d064c480SAlex Zinenko transform::TransformState::getPayloadOps(Value value) const {
410eb403adSAlex Zinenko const TransformOpMapping &operationMapping = getMapping(value).direct;
42d064c480SAlex Zinenko auto iter = operationMapping.find(value);
43d064c480SAlex Zinenko assert(iter != operationMapping.end() && "unknown handle");
44d064c480SAlex Zinenko return iter->getSecond();
45d064c480SAlex Zinenko }
46d064c480SAlex Zinenko
getHandleForPayloadOp(Operation * op) const476c57b0deSAlex Zinenko Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
486c57b0deSAlex Zinenko for (const Mappings &mapping : llvm::make_second_range(mappings)) {
496c57b0deSAlex Zinenko if (Value handle = mapping.reverse.lookup(op))
506c57b0deSAlex Zinenko return handle;
516c57b0deSAlex Zinenko }
526c57b0deSAlex Zinenko return Value();
536c57b0deSAlex Zinenko }
546c57b0deSAlex Zinenko
tryEmplaceReverseMapping(Mappings & map,Operation * operation,Value handle)556c57b0deSAlex Zinenko LogicalResult transform::TransformState::tryEmplaceReverseMapping(
566c57b0deSAlex Zinenko Mappings &map, Operation *operation, Value handle) {
576c57b0deSAlex Zinenko auto insertionResult = map.reverse.insert({operation, handle});
5800d1a1a2SAlex Zinenko if (!insertionResult.second && insertionResult.first->second != handle) {
596c57b0deSAlex Zinenko InFlightDiagnostic diag = operation->emitError()
606c57b0deSAlex Zinenko << "operation tracked by two handles";
616c57b0deSAlex Zinenko diag.attachNote(handle.getLoc()) << "handle";
626c57b0deSAlex Zinenko diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
636c57b0deSAlex Zinenko return diag;
646c57b0deSAlex Zinenko }
656c57b0deSAlex Zinenko return success();
666c57b0deSAlex Zinenko }
676c57b0deSAlex Zinenko
68d064c480SAlex Zinenko LogicalResult
setPayloadOps(Value value,ArrayRef<Operation * > targets)69d064c480SAlex Zinenko transform::TransformState::setPayloadOps(Value value,
70d064c480SAlex Zinenko ArrayRef<Operation *> targets) {
71d064c480SAlex Zinenko assert(value != kTopLevelValue &&
72d064c480SAlex Zinenko "attempting to reset the transformation root");
73d064c480SAlex Zinenko
74d064c480SAlex Zinenko if (value.use_empty())
75d064c480SAlex Zinenko return success();
76d064c480SAlex Zinenko
77d064c480SAlex Zinenko // Setting new payload for the value without cleaning it first is a misuse of
78d064c480SAlex Zinenko // the API, assert here.
79d064c480SAlex Zinenko SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
800eb403adSAlex Zinenko Mappings &mappings = getMapping(value);
81d064c480SAlex Zinenko bool inserted =
820eb403adSAlex Zinenko mappings.direct.insert({value, std::move(storedTargets)}).second;
83d064c480SAlex Zinenko assert(inserted && "value is already associated with another list");
84d064c480SAlex Zinenko (void)inserted;
85d064c480SAlex Zinenko
86d064c480SAlex Zinenko // Having multiple handles to the same operation is an error in the transform
87d064c480SAlex Zinenko // expressed using the dialect and may be constructed by valid API calls from
88d064c480SAlex Zinenko // valid IR. Emit an error here.
89d064c480SAlex Zinenko for (Operation *op : targets) {
906c57b0deSAlex Zinenko if (failed(tryEmplaceReverseMapping(mappings, op, value)))
916c57b0deSAlex Zinenko return failure();
92d064c480SAlex Zinenko }
93d064c480SAlex Zinenko
94d064c480SAlex Zinenko return success();
95d064c480SAlex Zinenko }
96d064c480SAlex Zinenko
removePayloadOps(Value value)97d064c480SAlex Zinenko void transform::TransformState::removePayloadOps(Value value) {
980eb403adSAlex Zinenko Mappings &mappings = getMapping(value);
990eb403adSAlex Zinenko for (Operation *op : mappings.direct[value])
1000eb403adSAlex Zinenko mappings.reverse.erase(op);
1010eb403adSAlex Zinenko mappings.direct.erase(value);
102d064c480SAlex Zinenko }
103d064c480SAlex Zinenko
updatePayloadOps(Value value,function_ref<Operation * (Operation *)> callback)1046c57b0deSAlex Zinenko LogicalResult transform::TransformState::updatePayloadOps(
105d064c480SAlex Zinenko Value value, function_ref<Operation *(Operation *)> callback) {
1066c57b0deSAlex Zinenko Mappings &mappings = getMapping(value);
1076c57b0deSAlex Zinenko auto it = mappings.direct.find(value);
1086c57b0deSAlex Zinenko assert(it != mappings.direct.end() && "unknown handle");
109d064c480SAlex Zinenko SmallVector<Operation *> &association = it->getSecond();
110d064c480SAlex Zinenko SmallVector<Operation *> updated;
111d064c480SAlex Zinenko updated.reserve(association.size());
112d064c480SAlex Zinenko
1136c57b0deSAlex Zinenko for (Operation *op : association) {
1146c57b0deSAlex Zinenko mappings.reverse.erase(op);
1156c57b0deSAlex Zinenko if (Operation *updatedOp = callback(op)) {
116d064c480SAlex Zinenko updated.push_back(updatedOp);
1176c57b0deSAlex Zinenko if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
1186c57b0deSAlex Zinenko return failure();
1196c57b0deSAlex Zinenko }
1206c57b0deSAlex Zinenko }
121d064c480SAlex Zinenko
122d064c480SAlex Zinenko std::swap(association, updated);
1236c57b0deSAlex Zinenko return success();
124d064c480SAlex Zinenko }
125d064c480SAlex Zinenko
recordHandleInvalidation(OpOperand & handle)1266403e1b1SAlex Zinenko void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
1276403e1b1SAlex Zinenko ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
1286403e1b1SAlex Zinenko for (const Mappings &mapping : llvm::make_second_range(mappings)) {
1296403e1b1SAlex Zinenko for (const auto &kvp : mapping.reverse) {
1306403e1b1SAlex Zinenko // If the op is associated with invalidated handle, skip the check as it
1316403e1b1SAlex Zinenko // may be reading invalid IR.
1326403e1b1SAlex Zinenko Operation *op = kvp.first;
1336403e1b1SAlex Zinenko Value otherHandle = kvp.second;
1346403e1b1SAlex Zinenko if (invalidatedHandles.count(otherHandle))
1356403e1b1SAlex Zinenko continue;
1366403e1b1SAlex Zinenko
1376403e1b1SAlex Zinenko for (Operation *ancestor : potentialAncestors) {
1386403e1b1SAlex Zinenko if (!ancestor->isProperAncestor(op))
1396403e1b1SAlex Zinenko continue;
1406403e1b1SAlex Zinenko
1416403e1b1SAlex Zinenko // Make sure the error-reporting lambda doesn't capture anything
1426403e1b1SAlex Zinenko // by-reference because it will go out of scope. Additionally, extract
1436403e1b1SAlex Zinenko // location from Payload IR ops because the ops themselves may be
1446403e1b1SAlex Zinenko // deleted before the lambda gets called.
1456403e1b1SAlex Zinenko Location ancestorLoc = ancestor->getLoc();
1466403e1b1SAlex Zinenko Location opLoc = op->getLoc();
1476403e1b1SAlex Zinenko Operation *owner = handle.getOwner();
1486403e1b1SAlex Zinenko unsigned operandNo = handle.getOperandNumber();
1496403e1b1SAlex Zinenko invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
1506403e1b1SAlex Zinenko otherHandle]() {
1516403e1b1SAlex Zinenko InFlightDiagnostic diag =
1526403e1b1SAlex Zinenko owner->emitOpError()
1536403e1b1SAlex Zinenko << "invalidated the handle to payload operations nested in the "
1546403e1b1SAlex Zinenko "payload operation associated with its operand #"
1556403e1b1SAlex Zinenko << operandNo;
1566403e1b1SAlex Zinenko diag.attachNote(ancestorLoc) << "ancestor op";
1576403e1b1SAlex Zinenko diag.attachNote(opLoc) << "nested op";
1586403e1b1SAlex Zinenko diag.attachNote(otherHandle.getLoc()) << "other handle";
1596403e1b1SAlex Zinenko };
1606403e1b1SAlex Zinenko }
1616403e1b1SAlex Zinenko }
1626403e1b1SAlex Zinenko }
1636403e1b1SAlex Zinenko }
1646403e1b1SAlex Zinenko
checkAndRecordHandleInvalidation(TransformOpInterface transform)1656403e1b1SAlex Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
1666403e1b1SAlex Zinenko TransformOpInterface transform) {
1676403e1b1SAlex Zinenko auto memoryEffectsIface =
1686403e1b1SAlex Zinenko cast<MemoryEffectOpInterface>(transform.getOperation());
1696403e1b1SAlex Zinenko SmallVector<MemoryEffects::EffectInstance> effects;
1706403e1b1SAlex Zinenko memoryEffectsIface.getEffectsOnResource(
1716403e1b1SAlex Zinenko transform::TransformMappingResource::get(), effects);
1726403e1b1SAlex Zinenko
1736403e1b1SAlex Zinenko for (OpOperand &target : transform->getOpOperands()) {
1746403e1b1SAlex Zinenko // If the operand uses an invalidated handle, report it.
1756403e1b1SAlex Zinenko auto it = invalidatedHandles.find(target.get());
1766403e1b1SAlex Zinenko if (it != invalidatedHandles.end())
1776403e1b1SAlex Zinenko return it->getSecond()(), failure();
1786403e1b1SAlex Zinenko
1796403e1b1SAlex Zinenko // Invalidate handles pointing to the operations nested in the operation
1806403e1b1SAlex Zinenko // associated with the handle consumed by this operation.
1816403e1b1SAlex Zinenko auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
1826403e1b1SAlex Zinenko return isa<MemoryEffects::Free>(effect.getEffect()) &&
1836403e1b1SAlex Zinenko effect.getValue() == target.get();
1846403e1b1SAlex Zinenko };
185*9e88cbccSKazu Hirata if (llvm::any_of(effects, consumesTarget))
1866403e1b1SAlex Zinenko recordHandleInvalidation(target);
1876403e1b1SAlex Zinenko }
1886403e1b1SAlex Zinenko return success();
1896403e1b1SAlex Zinenko }
1906403e1b1SAlex Zinenko
1911d45282aSAlex Zinenko DiagnosedSilenceableFailure
applyTransform(TransformOpInterface transform)192d064c480SAlex Zinenko transform::TransformState::applyTransform(TransformOpInterface transform) {
193e3890b7fSAlex Zinenko LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
19400d1a1a2SAlex Zinenko if (options.getExpensiveChecksEnabled()) {
19500d1a1a2SAlex Zinenko if (failed(checkAndRecordHandleInvalidation(transform)))
1961d45282aSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure();
19700d1a1a2SAlex Zinenko
19800d1a1a2SAlex Zinenko for (OpOperand &operand : transform->getOpOperands()) {
19900d1a1a2SAlex Zinenko if (!isHandleConsumed(operand.get(), transform))
20000d1a1a2SAlex Zinenko continue;
20100d1a1a2SAlex Zinenko
20200d1a1a2SAlex Zinenko DenseSet<Operation *> seen;
20300d1a1a2SAlex Zinenko for (Operation *op : getPayloadOps(operand.get())) {
20400d1a1a2SAlex Zinenko if (!seen.insert(op).second) {
20500d1a1a2SAlex Zinenko DiagnosedSilenceableFailure diag =
20600d1a1a2SAlex Zinenko transform.emitSilenceableError()
20700d1a1a2SAlex Zinenko << "a handle passed as operand #" << operand.getOperandNumber()
20800d1a1a2SAlex Zinenko << " and consumed by this operation points to a payload "
20900d1a1a2SAlex Zinenko "operation more than once";
21000d1a1a2SAlex Zinenko diag.attachNote(op->getLoc()) << "repeated target op";
21100d1a1a2SAlex Zinenko return diag;
21200d1a1a2SAlex Zinenko }
21300d1a1a2SAlex Zinenko }
21400d1a1a2SAlex Zinenko }
2156403e1b1SAlex Zinenko }
2166403e1b1SAlex Zinenko
217d064c480SAlex Zinenko transform::TransformResults results(transform->getNumResults());
2181d45282aSAlex Zinenko DiagnosedSilenceableFailure result(transform.apply(results, *this));
219e3890b7fSAlex Zinenko if (!result.succeeded())
220e3890b7fSAlex Zinenko return result;
221d064c480SAlex Zinenko
22240a8bd63SAlex Zinenko // Remove the mapping for the operand if it is consumed by the operation. This
22340a8bd63SAlex Zinenko // allows us to catch use-after-free with assertions later on.
22440a8bd63SAlex Zinenko auto memEffectInterface =
22540a8bd63SAlex Zinenko cast<MemoryEffectOpInterface>(transform.getOperation());
22640a8bd63SAlex Zinenko SmallVector<MemoryEffects::EffectInstance, 2> effects;
2276403e1b1SAlex Zinenko for (OpOperand &target : transform->getOpOperands()) {
22840a8bd63SAlex Zinenko effects.clear();
2296403e1b1SAlex Zinenko memEffectInterface.getEffectsOnValue(target.get(), effects);
23040a8bd63SAlex Zinenko if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
23140a8bd63SAlex Zinenko return isa<transform::TransformMappingResource>(
23240a8bd63SAlex Zinenko effect.getResource()) &&
23340a8bd63SAlex Zinenko isa<MemoryEffects::Free>(effect.getEffect());
23440a8bd63SAlex Zinenko })) {
2356403e1b1SAlex Zinenko removePayloadOps(target.get());
23640a8bd63SAlex Zinenko }
23740a8bd63SAlex Zinenko }
238d064c480SAlex Zinenko
2396403e1b1SAlex Zinenko for (OpResult result : transform->getResults()) {
2406403e1b1SAlex Zinenko assert(result.getDefiningOp() == transform.getOperation() &&
2410eb403adSAlex Zinenko "payload IR association for a value other than the result of the "
2420eb403adSAlex Zinenko "current transform op");
2436403e1b1SAlex Zinenko if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
2441d45282aSAlex Zinenko return DiagnosedSilenceableFailure::definiteFailure();
2450eb403adSAlex Zinenko }
246d064c480SAlex Zinenko
2471d45282aSAlex Zinenko return DiagnosedSilenceableFailure::success();
248d064c480SAlex Zinenko }
249d064c480SAlex Zinenko
2506c57b0deSAlex Zinenko //===----------------------------------------------------------------------===//
2516c57b0deSAlex Zinenko // TransformState::Extension
2526c57b0deSAlex Zinenko //===----------------------------------------------------------------------===//
2536c57b0deSAlex Zinenko
25430f22429SAlex Zinenko transform::TransformState::Extension::~Extension() = default;
25530f22429SAlex Zinenko
2566c57b0deSAlex Zinenko LogicalResult
replacePayloadOp(Operation * op,Operation * replacement)2576c57b0deSAlex Zinenko transform::TransformState::Extension::replacePayloadOp(Operation *op,
2586c57b0deSAlex Zinenko Operation *replacement) {
2596c57b0deSAlex Zinenko return state.updatePayloadOps(state.getHandleForPayloadOp(op),
2606c57b0deSAlex Zinenko [&](Operation *current) {
2616c57b0deSAlex Zinenko return current == op ? replacement : current;
2626c57b0deSAlex Zinenko });
2636c57b0deSAlex Zinenko }
2646c57b0deSAlex Zinenko
265d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
266d064c480SAlex Zinenko // TransformResults
267d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
268d064c480SAlex Zinenko
TransformResults(unsigned numSegments)269d064c480SAlex Zinenko transform::TransformResults::TransformResults(unsigned numSegments) {
270d064c480SAlex Zinenko segments.resize(numSegments,
271d064c480SAlex Zinenko ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
272d064c480SAlex Zinenko }
273d064c480SAlex Zinenko
set(OpResult value,ArrayRef<Operation * > ops)274d064c480SAlex Zinenko void transform::TransformResults::set(OpResult value,
275d064c480SAlex Zinenko ArrayRef<Operation *> ops) {
276d064c480SAlex Zinenko unsigned position = value.getResultNumber();
277d064c480SAlex Zinenko assert(position < segments.size() &&
278d064c480SAlex Zinenko "setting results for a non-existent handle");
279d064c480SAlex Zinenko assert(segments[position].data() == nullptr && "results already set");
280d064c480SAlex Zinenko unsigned start = operations.size();
281d064c480SAlex Zinenko llvm::append_range(operations, ops);
282d064c480SAlex Zinenko segments[position] = makeArrayRef(operations).drop_front(start);
283d064c480SAlex Zinenko }
284d064c480SAlex Zinenko
285d064c480SAlex Zinenko ArrayRef<Operation *>
get(unsigned resultNumber) const286d064c480SAlex Zinenko transform::TransformResults::get(unsigned resultNumber) const {
287d064c480SAlex Zinenko assert(resultNumber < segments.size() &&
288d064c480SAlex Zinenko "querying results for a non-existent handle");
289d064c480SAlex Zinenko assert(segments[resultNumber].data() != nullptr && "querying unset results");
290d064c480SAlex Zinenko return segments[resultNumber];
291d064c480SAlex Zinenko }
292d064c480SAlex Zinenko
293d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
29430f22429SAlex Zinenko // Utilities for PossibleTopLevelTransformOpTrait.
29530f22429SAlex Zinenko //===----------------------------------------------------------------------===//
29630f22429SAlex Zinenko
mapPossibleTopLevelTransformOpBlockArguments(TransformState & state,Operation * op,Region & region)29730f22429SAlex Zinenko LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
2981d45282aSAlex Zinenko TransformState &state, Operation *op, Region ®ion) {
29930f22429SAlex Zinenko SmallVector<Operation *> targets;
30030f22429SAlex Zinenko if (op->getNumOperands() != 0)
30130f22429SAlex Zinenko llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
30230f22429SAlex Zinenko else
30330f22429SAlex Zinenko targets.push_back(state.getTopLevel());
30430f22429SAlex Zinenko
3051d45282aSAlex Zinenko return state.mapBlockArguments(region.front().getArgument(0), targets);
30630f22429SAlex Zinenko }
30730f22429SAlex Zinenko
30830f22429SAlex Zinenko LogicalResult
verifyPossibleTopLevelTransformOpTrait(Operation * op)30930f22429SAlex Zinenko transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
31030f22429SAlex Zinenko // Attaching this trait without the interface is a misuse of the API, but it
31130f22429SAlex Zinenko // cannot be caught via a static_assert because interface registration is
31230f22429SAlex Zinenko // dynamic.
31330f22429SAlex Zinenko assert(isa<TransformOpInterface>(op) &&
31430f22429SAlex Zinenko "should implement TransformOpInterface to have "
31530f22429SAlex Zinenko "PossibleTopLevelTransformOpTrait");
31630f22429SAlex Zinenko
317e3890b7fSAlex Zinenko if (op->getNumRegions() < 1)
318e3890b7fSAlex Zinenko return op->emitOpError() << "expects at least one region";
31930f22429SAlex Zinenko
32030f22429SAlex Zinenko Region *bodyRegion = &op->getRegion(0);
32130f22429SAlex Zinenko if (!llvm::hasNItems(*bodyRegion, 1))
32230f22429SAlex Zinenko return op->emitOpError() << "expects a single-block region";
32330f22429SAlex Zinenko
32430f22429SAlex Zinenko Block *body = &bodyRegion->front();
32530f22429SAlex Zinenko if (body->getNumArguments() != 1 ||
32630f22429SAlex Zinenko !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
32730f22429SAlex Zinenko return op->emitOpError()
32830f22429SAlex Zinenko << "expects the entry block to have one argument of type "
32930f22429SAlex Zinenko << pdl::OperationType::get(op->getContext());
33030f22429SAlex Zinenko }
33130f22429SAlex Zinenko
33230f22429SAlex Zinenko if (auto *parent =
33330f22429SAlex Zinenko op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
33430f22429SAlex Zinenko if (op->getNumOperands() == 0) {
33530f22429SAlex Zinenko InFlightDiagnostic diag =
33630f22429SAlex Zinenko op->emitOpError()
33730f22429SAlex Zinenko << "expects the root operation to be provided for a nested op";
33830f22429SAlex Zinenko diag.attachNote(parent->getLoc())
33930f22429SAlex Zinenko << "nested in another possible top-level op";
34030f22429SAlex Zinenko return diag;
34130f22429SAlex Zinenko }
34230f22429SAlex Zinenko }
34330f22429SAlex Zinenko
34430f22429SAlex Zinenko return success();
34530f22429SAlex Zinenko }
34630f22429SAlex Zinenko
34730f22429SAlex Zinenko //===----------------------------------------------------------------------===//
34800d1a1a2SAlex Zinenko // Memory effects.
34900d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
35000d1a1a2SAlex Zinenko
consumesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)35100d1a1a2SAlex Zinenko void transform::consumesHandle(
35200d1a1a2SAlex Zinenko ValueRange handles,
35300d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
35400d1a1a2SAlex Zinenko for (Value handle : handles) {
35500d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Read::get(), handle,
35600d1a1a2SAlex Zinenko TransformMappingResource::get());
35700d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Free::get(), handle,
35800d1a1a2SAlex Zinenko TransformMappingResource::get());
35900d1a1a2SAlex Zinenko }
36000d1a1a2SAlex Zinenko }
36100d1a1a2SAlex Zinenko
36200d1a1a2SAlex Zinenko /// Returns `true` if the given list of effects instances contains an instance
36300d1a1a2SAlex Zinenko /// with the effect type specified as template parameter.
364e15b855eSAlex Zinenko template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects)36500d1a1a2SAlex Zinenko static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
36600d1a1a2SAlex Zinenko return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367e15b855eSAlex Zinenko return isa<EffectTy>(effect.getEffect()) &&
368e15b855eSAlex Zinenko isa<ResourceTy>(effect.getResource());
36900d1a1a2SAlex Zinenko });
37000d1a1a2SAlex Zinenko }
37100d1a1a2SAlex Zinenko
isHandleConsumed(Value handle,transform::TransformOpInterface transform)37200d1a1a2SAlex Zinenko bool transform::isHandleConsumed(Value handle,
37300d1a1a2SAlex Zinenko transform::TransformOpInterface transform) {
37400d1a1a2SAlex Zinenko auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
37500d1a1a2SAlex Zinenko SmallVector<MemoryEffects::EffectInstance> effects;
37600d1a1a2SAlex Zinenko iface.getEffectsOnValue(handle, effects);
377e15b855eSAlex Zinenko return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378e15b855eSAlex Zinenko hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
37900d1a1a2SAlex Zinenko }
38000d1a1a2SAlex Zinenko
producesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)38100d1a1a2SAlex Zinenko void transform::producesHandle(
38200d1a1a2SAlex Zinenko ValueRange handles,
38300d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
38400d1a1a2SAlex Zinenko for (Value handle : handles) {
38500d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Allocate::get(), handle,
38600d1a1a2SAlex Zinenko TransformMappingResource::get());
38700d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Write::get(), handle,
38800d1a1a2SAlex Zinenko TransformMappingResource::get());
38900d1a1a2SAlex Zinenko }
39000d1a1a2SAlex Zinenko }
39100d1a1a2SAlex Zinenko
onlyReadsHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)39200d1a1a2SAlex Zinenko void transform::onlyReadsHandle(
39300d1a1a2SAlex Zinenko ValueRange handles,
39400d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39500d1a1a2SAlex Zinenko for (Value handle : handles) {
39600d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Read::get(), handle,
39700d1a1a2SAlex Zinenko TransformMappingResource::get());
39800d1a1a2SAlex Zinenko }
39900d1a1a2SAlex Zinenko }
40000d1a1a2SAlex Zinenko
modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)40100d1a1a2SAlex Zinenko void transform::modifiesPayload(
40200d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
40300d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
40400d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
40500d1a1a2SAlex Zinenko }
40600d1a1a2SAlex Zinenko
onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)40700d1a1a2SAlex Zinenko void transform::onlyReadsPayload(
40800d1a1a2SAlex Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
40900d1a1a2SAlex Zinenko effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
41000d1a1a2SAlex Zinenko }
41100d1a1a2SAlex Zinenko
41200d1a1a2SAlex Zinenko //===----------------------------------------------------------------------===//
413d064c480SAlex Zinenko // Generated interface implementation.
414d064c480SAlex Zinenko //===----------------------------------------------------------------------===//
415d064c480SAlex Zinenko
416d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
417