1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "llvm/ADT/ScopeExit.h"
14 #include "llvm/ADT/SmallPtrSet.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // TransformState
20 //===----------------------------------------------------------------------===//
21 
// Out-of-line definition of the static constexpr member. No initializer is
// given here, so the value is presumably supplied at the in-class declaration
// in the header (not visible in this file).
constexpr const Value transform::TransformState::kTopLevelValue;
23 
24 transform::TransformState::TransformState(Region &region, Operation *root)
25     : topLevel(root) {
26   auto result = mappings.try_emplace(&region);
27   assert(result.second && "the region scope is already present");
28   (void)result;
29 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
30   regionStack.push_back(&region);
31 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
32 }
33 
34 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
35 
36 ArrayRef<Operation *>
37 transform::TransformState::getPayloadOps(Value value) const {
38   const TransformOpMapping &operationMapping = getMapping(value).direct;
39   auto iter = operationMapping.find(value);
40   assert(iter != operationMapping.end() && "unknown handle");
41   return iter->getSecond();
42 }
43 
44 LogicalResult
45 transform::TransformState::setPayloadOps(Value value,
46                                          ArrayRef<Operation *> targets) {
47   assert(value != kTopLevelValue &&
48          "attempting to reset the transformation root");
49 
50   if (value.use_empty())
51     return success();
52 
53   // Setting new payload for the value without cleaning it first is a misuse of
54   // the API, assert here.
55   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
56   Mappings &mappings = getMapping(value);
57   bool inserted =
58       mappings.direct.insert({value, std::move(storedTargets)}).second;
59   assert(inserted && "value is already associated with another list");
60   (void)inserted;
61 
62   // Having multiple handles to the same operation is an error in the transform
63   // expressed using the dialect and may be constructed by valid API calls from
64   // valid IR. Emit an error here.
65   for (Operation *op : targets) {
66     auto insertionResult = mappings.reverse.insert({op, value});
67     if (!insertionResult.second) {
68       InFlightDiagnostic diag = op->emitError()
69                                 << "operation tracked by two handles";
70       diag.attachNote(value.getLoc()) << "handle";
71       diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
72       return diag;
73     }
74   }
75 
76   return success();
77 }
78 
79 void transform::TransformState::removePayloadOps(Value value) {
80   Mappings &mappings = getMapping(value);
81   for (Operation *op : mappings.direct[value])
82     mappings.reverse.erase(op);
83   mappings.direct.erase(value);
84 }
85 
86 void transform::TransformState::updatePayloadOps(
87     Value value, function_ref<Operation *(Operation *)> callback) {
88   auto it = getMapping(value).direct.find(value);
89   assert(it != getMapping(value).direct.end() && "unknown handle");
90   SmallVector<Operation *> &association = it->getSecond();
91   SmallVector<Operation *> updated;
92   updated.reserve(association.size());
93 
94   for (Operation *op : association)
95     if (Operation *updatedOp = callback(op))
96       updated.push_back(updatedOp);
97 
98   std::swap(association, updated);
99 }
100 
101 LogicalResult
102 transform::TransformState::applyTransform(TransformOpInterface transform) {
103   transform::TransformResults results(transform->getNumResults());
104   if (failed(transform.apply(results, *this)))
105     return failure();
106 
107   for (Value target : transform->getOperands())
108     removePayloadOps(target);
109 
110   for (auto &en : llvm::enumerate(transform->getResults())) {
111     assert(en.value().getDefiningOp() == transform.getOperation() &&
112            "payload IR association for a value other than the result of the "
113            "current transform op");
114     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
115       return failure();
116   }
117 
118   return success();
119 }
120 
// The destructor is defaulted here, out of line, rather than in the class
// body. NOTE(review): presumably this anchors the vtable of the Extension base
// class in this translation unit — confirm against the header.
transform::TransformState::Extension::~Extension() = default;
122 
123 //===----------------------------------------------------------------------===//
124 // TransformResults
125 //===----------------------------------------------------------------------===//
126 
127 transform::TransformResults::TransformResults(unsigned numSegments) {
128   segments.resize(numSegments,
129                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
130 }
131 
132 void transform::TransformResults::set(OpResult value,
133                                       ArrayRef<Operation *> ops) {
134   unsigned position = value.getResultNumber();
135   assert(position < segments.size() &&
136          "setting results for a non-existent handle");
137   assert(segments[position].data() == nullptr && "results already set");
138   unsigned start = operations.size();
139   llvm::append_range(operations, ops);
140   segments[position] = makeArrayRef(operations).drop_front(start);
141 }
142 
143 ArrayRef<Operation *>
144 transform::TransformResults::get(unsigned resultNumber) const {
145   assert(resultNumber < segments.size() &&
146          "querying results for a non-existent handle");
147   assert(segments[resultNumber].data() != nullptr && "querying unset results");
148   return segments[resultNumber];
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Utilities for PossibleTopLevelTransformOpTrait.
153 //===----------------------------------------------------------------------===//
154 
155 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
156     TransformState &state, Operation *op) {
157   SmallVector<Operation *> targets;
158   if (op->getNumOperands() != 0)
159     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
160   else
161     targets.push_back(state.getTopLevel());
162 
163   return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
164                                  targets);
165 }
166 
167 LogicalResult
168 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
169   // Attaching this trait without the interface is a misuse of the API, but it
170   // cannot be caught via a static_assert because interface registration is
171   // dynamic.
172   assert(isa<TransformOpInterface>(op) &&
173          "should implement TransformOpInterface to have "
174          "PossibleTopLevelTransformOpTrait");
175 
176   if (op->getNumRegions() != 1)
177     return op->emitOpError() << "expects one region";
178 
179   Region *bodyRegion = &op->getRegion(0);
180   if (!llvm::hasNItems(*bodyRegion, 1))
181     return op->emitOpError() << "expects a single-block region";
182 
183   Block *body = &bodyRegion->front();
184   if (body->getNumArguments() != 1 ||
185       !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
186     return op->emitOpError()
187            << "expects the entry block to have one argument of type "
188            << pdl::OperationType::get(op->getContext());
189   }
190 
191   if (auto *parent =
192           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
193     if (op->getNumOperands() == 0) {
194       InFlightDiagnostic diag =
195           op->emitOpError()
196           << "expects the root operation to be provided for a nested op";
197       diag.attachNote(parent->getLoc())
198           << "nested in another possible top-level op";
199       return diag;
200     }
201   }
202 
203   return success();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // Generated interface implementation.
208 //===----------------------------------------------------------------------===//
209 
210 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
211