//===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8d064c480SAlex Zinenko 
9d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1030f22429SAlex Zinenko #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11d064c480SAlex Zinenko #include "mlir/IR/Diagnostics.h"
12d064c480SAlex Zinenko #include "mlir/IR/Operation.h"
13e3890b7fSAlex Zinenko #include "llvm/Support/Debug.h"
14e3890b7fSAlex Zinenko 
15e3890b7fSAlex Zinenko #define DEBUG_TYPE "transform-dialect"
16e3890b7fSAlex Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
17d064c480SAlex Zinenko 
18d064c480SAlex Zinenko using namespace mlir;
19d064c480SAlex Zinenko 
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
23d064c480SAlex Zinenko 
24d064c480SAlex Zinenko constexpr const Value transform::TransformState::kTopLevelValue;
25d064c480SAlex Zinenko 
TransformState(Region & region,Operation * root,const TransformOptions & options)266403e1b1SAlex Zinenko transform::TransformState::TransformState(Region &region, Operation *root,
276403e1b1SAlex Zinenko                                           const TransformOptions &options)
286403e1b1SAlex Zinenko     : topLevel(root), options(options) {
290eb403adSAlex Zinenko   auto result = mappings.try_emplace(&region);
300eb403adSAlex Zinenko   assert(result.second && "the region scope is already present");
310eb403adSAlex Zinenko   (void)result;
320eb403adSAlex Zinenko #if LLVM_ENABLE_ABI_BREAKING_CHECKS
330eb403adSAlex Zinenko   regionStack.push_back(&region);
340eb403adSAlex Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
35d064c480SAlex Zinenko }
36d064c480SAlex Zinenko 
getTopLevel() const370eb403adSAlex Zinenko Operation *transform::TransformState::getTopLevel() const { return topLevel; }
38d064c480SAlex Zinenko 
39d064c480SAlex Zinenko ArrayRef<Operation *>
getPayloadOps(Value value) const40d064c480SAlex Zinenko transform::TransformState::getPayloadOps(Value value) const {
410eb403adSAlex Zinenko   const TransformOpMapping &operationMapping = getMapping(value).direct;
42d064c480SAlex Zinenko   auto iter = operationMapping.find(value);
43d064c480SAlex Zinenko   assert(iter != operationMapping.end() && "unknown handle");
44d064c480SAlex Zinenko   return iter->getSecond();
45d064c480SAlex Zinenko }
46d064c480SAlex Zinenko 
getHandleForPayloadOp(Operation * op) const476c57b0deSAlex Zinenko Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
486c57b0deSAlex Zinenko   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
496c57b0deSAlex Zinenko     if (Value handle = mapping.reverse.lookup(op))
506c57b0deSAlex Zinenko       return handle;
516c57b0deSAlex Zinenko   }
526c57b0deSAlex Zinenko   return Value();
536c57b0deSAlex Zinenko }
546c57b0deSAlex Zinenko 
tryEmplaceReverseMapping(Mappings & map,Operation * operation,Value handle)556c57b0deSAlex Zinenko LogicalResult transform::TransformState::tryEmplaceReverseMapping(
566c57b0deSAlex Zinenko     Mappings &map, Operation *operation, Value handle) {
576c57b0deSAlex Zinenko   auto insertionResult = map.reverse.insert({operation, handle});
5800d1a1a2SAlex Zinenko   if (!insertionResult.second && insertionResult.first->second != handle) {
596c57b0deSAlex Zinenko     InFlightDiagnostic diag = operation->emitError()
606c57b0deSAlex Zinenko                               << "operation tracked by two handles";
616c57b0deSAlex Zinenko     diag.attachNote(handle.getLoc()) << "handle";
626c57b0deSAlex Zinenko     diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
636c57b0deSAlex Zinenko     return diag;
646c57b0deSAlex Zinenko   }
656c57b0deSAlex Zinenko   return success();
666c57b0deSAlex Zinenko }
676c57b0deSAlex Zinenko 
68d064c480SAlex Zinenko LogicalResult
setPayloadOps(Value value,ArrayRef<Operation * > targets)69d064c480SAlex Zinenko transform::TransformState::setPayloadOps(Value value,
70d064c480SAlex Zinenko                                          ArrayRef<Operation *> targets) {
71d064c480SAlex Zinenko   assert(value != kTopLevelValue &&
72d064c480SAlex Zinenko          "attempting to reset the transformation root");
73d064c480SAlex Zinenko 
74d064c480SAlex Zinenko   if (value.use_empty())
75d064c480SAlex Zinenko     return success();
76d064c480SAlex Zinenko 
77d064c480SAlex Zinenko   // Setting new payload for the value without cleaning it first is a misuse of
78d064c480SAlex Zinenko   // the API, assert here.
79d064c480SAlex Zinenko   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
800eb403adSAlex Zinenko   Mappings &mappings = getMapping(value);
81d064c480SAlex Zinenko   bool inserted =
820eb403adSAlex Zinenko       mappings.direct.insert({value, std::move(storedTargets)}).second;
83d064c480SAlex Zinenko   assert(inserted && "value is already associated with another list");
84d064c480SAlex Zinenko   (void)inserted;
85d064c480SAlex Zinenko 
86d064c480SAlex Zinenko   // Having multiple handles to the same operation is an error in the transform
87d064c480SAlex Zinenko   // expressed using the dialect and may be constructed by valid API calls from
88d064c480SAlex Zinenko   // valid IR. Emit an error here.
89d064c480SAlex Zinenko   for (Operation *op : targets) {
906c57b0deSAlex Zinenko     if (failed(tryEmplaceReverseMapping(mappings, op, value)))
916c57b0deSAlex Zinenko       return failure();
92d064c480SAlex Zinenko   }
93d064c480SAlex Zinenko 
94d064c480SAlex Zinenko   return success();
95d064c480SAlex Zinenko }
96d064c480SAlex Zinenko 
removePayloadOps(Value value)97d064c480SAlex Zinenko void transform::TransformState::removePayloadOps(Value value) {
980eb403adSAlex Zinenko   Mappings &mappings = getMapping(value);
990eb403adSAlex Zinenko   for (Operation *op : mappings.direct[value])
1000eb403adSAlex Zinenko     mappings.reverse.erase(op);
1010eb403adSAlex Zinenko   mappings.direct.erase(value);
102d064c480SAlex Zinenko }
103d064c480SAlex Zinenko 
updatePayloadOps(Value value,function_ref<Operation * (Operation *)> callback)1046c57b0deSAlex Zinenko LogicalResult transform::TransformState::updatePayloadOps(
105d064c480SAlex Zinenko     Value value, function_ref<Operation *(Operation *)> callback) {
1066c57b0deSAlex Zinenko   Mappings &mappings = getMapping(value);
1076c57b0deSAlex Zinenko   auto it = mappings.direct.find(value);
1086c57b0deSAlex Zinenko   assert(it != mappings.direct.end() && "unknown handle");
109d064c480SAlex Zinenko   SmallVector<Operation *> &association = it->getSecond();
110d064c480SAlex Zinenko   SmallVector<Operation *> updated;
111d064c480SAlex Zinenko   updated.reserve(association.size());
112d064c480SAlex Zinenko 
1136c57b0deSAlex Zinenko   for (Operation *op : association) {
1146c57b0deSAlex Zinenko     mappings.reverse.erase(op);
1156c57b0deSAlex Zinenko     if (Operation *updatedOp = callback(op)) {
116d064c480SAlex Zinenko       updated.push_back(updatedOp);
1176c57b0deSAlex Zinenko       if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
1186c57b0deSAlex Zinenko         return failure();
1196c57b0deSAlex Zinenko     }
1206c57b0deSAlex Zinenko   }
121d064c480SAlex Zinenko 
122d064c480SAlex Zinenko   std::swap(association, updated);
1236c57b0deSAlex Zinenko   return success();
124d064c480SAlex Zinenko }
125d064c480SAlex Zinenko 
recordHandleInvalidation(OpOperand & handle)1266403e1b1SAlex Zinenko void transform::TransformState::recordHandleInvalidation(OpOperand &handle) {
1276403e1b1SAlex Zinenko   ArrayRef<Operation *> potentialAncestors = getPayloadOps(handle.get());
1286403e1b1SAlex Zinenko   for (const Mappings &mapping : llvm::make_second_range(mappings)) {
1296403e1b1SAlex Zinenko     for (const auto &kvp : mapping.reverse) {
1306403e1b1SAlex Zinenko       // If the op is associated with invalidated handle, skip the check as it
1316403e1b1SAlex Zinenko       // may be reading invalid IR.
1326403e1b1SAlex Zinenko       Operation *op = kvp.first;
1336403e1b1SAlex Zinenko       Value otherHandle = kvp.second;
1346403e1b1SAlex Zinenko       if (invalidatedHandles.count(otherHandle))
1356403e1b1SAlex Zinenko         continue;
1366403e1b1SAlex Zinenko 
1376403e1b1SAlex Zinenko       for (Operation *ancestor : potentialAncestors) {
1386403e1b1SAlex Zinenko         if (!ancestor->isProperAncestor(op))
1396403e1b1SAlex Zinenko           continue;
1406403e1b1SAlex Zinenko 
1416403e1b1SAlex Zinenko         // Make sure the error-reporting lambda doesn't capture anything
1426403e1b1SAlex Zinenko         // by-reference because it will go out of scope. Additionally, extract
1436403e1b1SAlex Zinenko         // location from Payload IR ops because the ops themselves may be
1446403e1b1SAlex Zinenko         // deleted before the lambda gets called.
1456403e1b1SAlex Zinenko         Location ancestorLoc = ancestor->getLoc();
1466403e1b1SAlex Zinenko         Location opLoc = op->getLoc();
1476403e1b1SAlex Zinenko         Operation *owner = handle.getOwner();
1486403e1b1SAlex Zinenko         unsigned operandNo = handle.getOperandNumber();
1496403e1b1SAlex Zinenko         invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
1506403e1b1SAlex Zinenko                                            otherHandle]() {
1516403e1b1SAlex Zinenko           InFlightDiagnostic diag =
1526403e1b1SAlex Zinenko               owner->emitOpError()
1536403e1b1SAlex Zinenko               << "invalidated the handle to payload operations nested in the "
1546403e1b1SAlex Zinenko                  "payload operation associated with its operand #"
1556403e1b1SAlex Zinenko               << operandNo;
1566403e1b1SAlex Zinenko           diag.attachNote(ancestorLoc) << "ancestor op";
1576403e1b1SAlex Zinenko           diag.attachNote(opLoc) << "nested op";
1586403e1b1SAlex Zinenko           diag.attachNote(otherHandle.getLoc()) << "other handle";
1596403e1b1SAlex Zinenko         };
1606403e1b1SAlex Zinenko       }
1616403e1b1SAlex Zinenko     }
1626403e1b1SAlex Zinenko   }
1636403e1b1SAlex Zinenko }
1646403e1b1SAlex Zinenko 
checkAndRecordHandleInvalidation(TransformOpInterface transform)1656403e1b1SAlex Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
1666403e1b1SAlex Zinenko     TransformOpInterface transform) {
1676403e1b1SAlex Zinenko   auto memoryEffectsIface =
1686403e1b1SAlex Zinenko       cast<MemoryEffectOpInterface>(transform.getOperation());
1696403e1b1SAlex Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
1706403e1b1SAlex Zinenko   memoryEffectsIface.getEffectsOnResource(
1716403e1b1SAlex Zinenko       transform::TransformMappingResource::get(), effects);
1726403e1b1SAlex Zinenko 
1736403e1b1SAlex Zinenko   for (OpOperand &target : transform->getOpOperands()) {
1746403e1b1SAlex Zinenko     // If the operand uses an invalidated handle, report it.
1756403e1b1SAlex Zinenko     auto it = invalidatedHandles.find(target.get());
1766403e1b1SAlex Zinenko     if (it != invalidatedHandles.end())
1776403e1b1SAlex Zinenko       return it->getSecond()(), failure();
1786403e1b1SAlex Zinenko 
1796403e1b1SAlex Zinenko     // Invalidate handles pointing to the operations nested in the operation
1806403e1b1SAlex Zinenko     // associated with the handle consumed by this operation.
1816403e1b1SAlex Zinenko     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
1826403e1b1SAlex Zinenko       return isa<MemoryEffects::Free>(effect.getEffect()) &&
1836403e1b1SAlex Zinenko              effect.getValue() == target.get();
1846403e1b1SAlex Zinenko     };
185*9e88cbccSKazu Hirata     if (llvm::any_of(effects, consumesTarget))
1866403e1b1SAlex Zinenko       recordHandleInvalidation(target);
1876403e1b1SAlex Zinenko   }
1886403e1b1SAlex Zinenko   return success();
1896403e1b1SAlex Zinenko }
1906403e1b1SAlex Zinenko 
1911d45282aSAlex Zinenko DiagnosedSilenceableFailure
applyTransform(TransformOpInterface transform)192d064c480SAlex Zinenko transform::TransformState::applyTransform(TransformOpInterface transform) {
193e3890b7fSAlex Zinenko   LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
19400d1a1a2SAlex Zinenko   if (options.getExpensiveChecksEnabled()) {
19500d1a1a2SAlex Zinenko     if (failed(checkAndRecordHandleInvalidation(transform)))
1961d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
19700d1a1a2SAlex Zinenko 
19800d1a1a2SAlex Zinenko     for (OpOperand &operand : transform->getOpOperands()) {
19900d1a1a2SAlex Zinenko       if (!isHandleConsumed(operand.get(), transform))
20000d1a1a2SAlex Zinenko         continue;
20100d1a1a2SAlex Zinenko 
20200d1a1a2SAlex Zinenko       DenseSet<Operation *> seen;
20300d1a1a2SAlex Zinenko       for (Operation *op : getPayloadOps(operand.get())) {
20400d1a1a2SAlex Zinenko         if (!seen.insert(op).second) {
20500d1a1a2SAlex Zinenko           DiagnosedSilenceableFailure diag =
20600d1a1a2SAlex Zinenko               transform.emitSilenceableError()
20700d1a1a2SAlex Zinenko               << "a handle passed as operand #" << operand.getOperandNumber()
20800d1a1a2SAlex Zinenko               << " and consumed by this operation points to a payload "
20900d1a1a2SAlex Zinenko                  "operation more than once";
21000d1a1a2SAlex Zinenko           diag.attachNote(op->getLoc()) << "repeated target op";
21100d1a1a2SAlex Zinenko           return diag;
21200d1a1a2SAlex Zinenko         }
21300d1a1a2SAlex Zinenko       }
21400d1a1a2SAlex Zinenko     }
2156403e1b1SAlex Zinenko   }
2166403e1b1SAlex Zinenko 
217d064c480SAlex Zinenko   transform::TransformResults results(transform->getNumResults());
2181d45282aSAlex Zinenko   DiagnosedSilenceableFailure result(transform.apply(results, *this));
219e3890b7fSAlex Zinenko   if (!result.succeeded())
220e3890b7fSAlex Zinenko     return result;
221d064c480SAlex Zinenko 
22240a8bd63SAlex Zinenko   // Remove the mapping for the operand if it is consumed by the operation. This
22340a8bd63SAlex Zinenko   // allows us to catch use-after-free with assertions later on.
22440a8bd63SAlex Zinenko   auto memEffectInterface =
22540a8bd63SAlex Zinenko       cast<MemoryEffectOpInterface>(transform.getOperation());
22640a8bd63SAlex Zinenko   SmallVector<MemoryEffects::EffectInstance, 2> effects;
2276403e1b1SAlex Zinenko   for (OpOperand &target : transform->getOpOperands()) {
22840a8bd63SAlex Zinenko     effects.clear();
2296403e1b1SAlex Zinenko     memEffectInterface.getEffectsOnValue(target.get(), effects);
23040a8bd63SAlex Zinenko     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
23140a8bd63SAlex Zinenko           return isa<transform::TransformMappingResource>(
23240a8bd63SAlex Zinenko                      effect.getResource()) &&
23340a8bd63SAlex Zinenko                  isa<MemoryEffects::Free>(effect.getEffect());
23440a8bd63SAlex Zinenko         })) {
2356403e1b1SAlex Zinenko       removePayloadOps(target.get());
23640a8bd63SAlex Zinenko     }
23740a8bd63SAlex Zinenko   }
238d064c480SAlex Zinenko 
2396403e1b1SAlex Zinenko   for (OpResult result : transform->getResults()) {
2406403e1b1SAlex Zinenko     assert(result.getDefiningOp() == transform.getOperation() &&
2410eb403adSAlex Zinenko            "payload IR association for a value other than the result of the "
2420eb403adSAlex Zinenko            "current transform op");
2436403e1b1SAlex Zinenko     if (failed(setPayloadOps(result, results.get(result.getResultNumber()))))
2441d45282aSAlex Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
2450eb403adSAlex Zinenko   }
246d064c480SAlex Zinenko 
2471d45282aSAlex Zinenko   return DiagnosedSilenceableFailure::success();
248d064c480SAlex Zinenko }
249d064c480SAlex Zinenko 
//===----------------------------------------------------------------------===//
// TransformState::Extension
//===----------------------------------------------------------------------===//
2536c57b0deSAlex Zinenko 
25430f22429SAlex Zinenko transform::TransformState::Extension::~Extension() = default;
25530f22429SAlex Zinenko 
2566c57b0deSAlex Zinenko LogicalResult
replacePayloadOp(Operation * op,Operation * replacement)2576c57b0deSAlex Zinenko transform::TransformState::Extension::replacePayloadOp(Operation *op,
2586c57b0deSAlex Zinenko                                                        Operation *replacement) {
2596c57b0deSAlex Zinenko   return state.updatePayloadOps(state.getHandleForPayloadOp(op),
2606c57b0deSAlex Zinenko                                 [&](Operation *current) {
2616c57b0deSAlex Zinenko                                   return current == op ? replacement : current;
2626c57b0deSAlex Zinenko                                 });
2636c57b0deSAlex Zinenko }
2646c57b0deSAlex Zinenko 
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
268d064c480SAlex Zinenko 
TransformResults(unsigned numSegments)269d064c480SAlex Zinenko transform::TransformResults::TransformResults(unsigned numSegments) {
270d064c480SAlex Zinenko   segments.resize(numSegments,
271d064c480SAlex Zinenko                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
272d064c480SAlex Zinenko }
273d064c480SAlex Zinenko 
set(OpResult value,ArrayRef<Operation * > ops)274d064c480SAlex Zinenko void transform::TransformResults::set(OpResult value,
275d064c480SAlex Zinenko                                       ArrayRef<Operation *> ops) {
276d064c480SAlex Zinenko   unsigned position = value.getResultNumber();
277d064c480SAlex Zinenko   assert(position < segments.size() &&
278d064c480SAlex Zinenko          "setting results for a non-existent handle");
279d064c480SAlex Zinenko   assert(segments[position].data() == nullptr && "results already set");
280d064c480SAlex Zinenko   unsigned start = operations.size();
281d064c480SAlex Zinenko   llvm::append_range(operations, ops);
282d064c480SAlex Zinenko   segments[position] = makeArrayRef(operations).drop_front(start);
283d064c480SAlex Zinenko }
284d064c480SAlex Zinenko 
285d064c480SAlex Zinenko ArrayRef<Operation *>
get(unsigned resultNumber) const286d064c480SAlex Zinenko transform::TransformResults::get(unsigned resultNumber) const {
287d064c480SAlex Zinenko   assert(resultNumber < segments.size() &&
288d064c480SAlex Zinenko          "querying results for a non-existent handle");
289d064c480SAlex Zinenko   assert(segments[resultNumber].data() != nullptr && "querying unset results");
290d064c480SAlex Zinenko   return segments[resultNumber];
291d064c480SAlex Zinenko }
292d064c480SAlex Zinenko 
//===----------------------------------------------------------------------===//
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
29630f22429SAlex Zinenko 
mapPossibleTopLevelTransformOpBlockArguments(TransformState & state,Operation * op,Region & region)29730f22429SAlex Zinenko LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
2981d45282aSAlex Zinenko     TransformState &state, Operation *op, Region &region) {
29930f22429SAlex Zinenko   SmallVector<Operation *> targets;
30030f22429SAlex Zinenko   if (op->getNumOperands() != 0)
30130f22429SAlex Zinenko     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
30230f22429SAlex Zinenko   else
30330f22429SAlex Zinenko     targets.push_back(state.getTopLevel());
30430f22429SAlex Zinenko 
3051d45282aSAlex Zinenko   return state.mapBlockArguments(region.front().getArgument(0), targets);
30630f22429SAlex Zinenko }
30730f22429SAlex Zinenko 
30830f22429SAlex Zinenko LogicalResult
verifyPossibleTopLevelTransformOpTrait(Operation * op)30930f22429SAlex Zinenko transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
31030f22429SAlex Zinenko   // Attaching this trait without the interface is a misuse of the API, but it
31130f22429SAlex Zinenko   // cannot be caught via a static_assert because interface registration is
31230f22429SAlex Zinenko   // dynamic.
31330f22429SAlex Zinenko   assert(isa<TransformOpInterface>(op) &&
31430f22429SAlex Zinenko          "should implement TransformOpInterface to have "
31530f22429SAlex Zinenko          "PossibleTopLevelTransformOpTrait");
31630f22429SAlex Zinenko 
317e3890b7fSAlex Zinenko   if (op->getNumRegions() < 1)
318e3890b7fSAlex Zinenko     return op->emitOpError() << "expects at least one region";
31930f22429SAlex Zinenko 
32030f22429SAlex Zinenko   Region *bodyRegion = &op->getRegion(0);
32130f22429SAlex Zinenko   if (!llvm::hasNItems(*bodyRegion, 1))
32230f22429SAlex Zinenko     return op->emitOpError() << "expects a single-block region";
32330f22429SAlex Zinenko 
32430f22429SAlex Zinenko   Block *body = &bodyRegion->front();
32530f22429SAlex Zinenko   if (body->getNumArguments() != 1 ||
32630f22429SAlex Zinenko       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
32730f22429SAlex Zinenko     return op->emitOpError()
32830f22429SAlex Zinenko            << "expects the entry block to have one argument of type "
32930f22429SAlex Zinenko            << pdl::OperationType::get(op->getContext());
33030f22429SAlex Zinenko   }
33130f22429SAlex Zinenko 
33230f22429SAlex Zinenko   if (auto *parent =
33330f22429SAlex Zinenko           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
33430f22429SAlex Zinenko     if (op->getNumOperands() == 0) {
33530f22429SAlex Zinenko       InFlightDiagnostic diag =
33630f22429SAlex Zinenko           op->emitOpError()
33730f22429SAlex Zinenko           << "expects the root operation to be provided for a nested op";
33830f22429SAlex Zinenko       diag.attachNote(parent->getLoc())
33930f22429SAlex Zinenko           << "nested in another possible top-level op";
34030f22429SAlex Zinenko       return diag;
34130f22429SAlex Zinenko     }
34230f22429SAlex Zinenko   }
34330f22429SAlex Zinenko 
34430f22429SAlex Zinenko   return success();
34530f22429SAlex Zinenko }
34630f22429SAlex Zinenko 
//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
35000d1a1a2SAlex Zinenko 
consumesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)35100d1a1a2SAlex Zinenko void transform::consumesHandle(
35200d1a1a2SAlex Zinenko     ValueRange handles,
35300d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
35400d1a1a2SAlex Zinenko   for (Value handle : handles) {
35500d1a1a2SAlex Zinenko     effects.emplace_back(MemoryEffects::Read::get(), handle,
35600d1a1a2SAlex Zinenko                          TransformMappingResource::get());
35700d1a1a2SAlex Zinenko     effects.emplace_back(MemoryEffects::Free::get(), handle,
35800d1a1a2SAlex Zinenko                          TransformMappingResource::get());
35900d1a1a2SAlex Zinenko   }
36000d1a1a2SAlex Zinenko }
36100d1a1a2SAlex Zinenko 
36200d1a1a2SAlex Zinenko /// Returns `true` if the given list of effects instances contains an instance
36300d1a1a2SAlex Zinenko /// with the effect type specified as template parameter.
364e15b855eSAlex Zinenko template <typename EffectTy, typename ResourceTy = SideEffects::DefaultResource>
hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects)36500d1a1a2SAlex Zinenko static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
36600d1a1a2SAlex Zinenko   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
367e15b855eSAlex Zinenko     return isa<EffectTy>(effect.getEffect()) &&
368e15b855eSAlex Zinenko            isa<ResourceTy>(effect.getResource());
36900d1a1a2SAlex Zinenko   });
37000d1a1a2SAlex Zinenko }
37100d1a1a2SAlex Zinenko 
isHandleConsumed(Value handle,transform::TransformOpInterface transform)37200d1a1a2SAlex Zinenko bool transform::isHandleConsumed(Value handle,
37300d1a1a2SAlex Zinenko                                  transform::TransformOpInterface transform) {
37400d1a1a2SAlex Zinenko   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
37500d1a1a2SAlex Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
37600d1a1a2SAlex Zinenko   iface.getEffectsOnValue(handle, effects);
377e15b855eSAlex Zinenko   return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378e15b855eSAlex Zinenko          hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
37900d1a1a2SAlex Zinenko }
38000d1a1a2SAlex Zinenko 
producesHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)38100d1a1a2SAlex Zinenko void transform::producesHandle(
38200d1a1a2SAlex Zinenko     ValueRange handles,
38300d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
38400d1a1a2SAlex Zinenko   for (Value handle : handles) {
38500d1a1a2SAlex Zinenko     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
38600d1a1a2SAlex Zinenko                          TransformMappingResource::get());
38700d1a1a2SAlex Zinenko     effects.emplace_back(MemoryEffects::Write::get(), handle,
38800d1a1a2SAlex Zinenko                          TransformMappingResource::get());
38900d1a1a2SAlex Zinenko   }
39000d1a1a2SAlex Zinenko }
39100d1a1a2SAlex Zinenko 
onlyReadsHandle(ValueRange handles,SmallVectorImpl<MemoryEffects::EffectInstance> & effects)39200d1a1a2SAlex Zinenko void transform::onlyReadsHandle(
39300d1a1a2SAlex Zinenko     ValueRange handles,
39400d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39500d1a1a2SAlex Zinenko   for (Value handle : handles) {
39600d1a1a2SAlex Zinenko     effects.emplace_back(MemoryEffects::Read::get(), handle,
39700d1a1a2SAlex Zinenko                          TransformMappingResource::get());
39800d1a1a2SAlex Zinenko   }
39900d1a1a2SAlex Zinenko }
40000d1a1a2SAlex Zinenko 
modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)40100d1a1a2SAlex Zinenko void transform::modifiesPayload(
40200d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
40300d1a1a2SAlex Zinenko   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
40400d1a1a2SAlex Zinenko   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
40500d1a1a2SAlex Zinenko }
40600d1a1a2SAlex Zinenko 
onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)40700d1a1a2SAlex Zinenko void transform::onlyReadsPayload(
40800d1a1a2SAlex Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
40900d1a1a2SAlex Zinenko   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
41000d1a1a2SAlex Zinenko }
41100d1a1a2SAlex Zinenko 
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
415d064c480SAlex Zinenko 
416d064c480SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
417