1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Operation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // TransformState
19 //===----------------------------------------------------------------------===//
20 
// Out-of-line definition of the static constexpr data member declared in the
// class. NOTE(review): presumably needed because the member is ODR-used (it is
// compared against in setPayloadOps) and the code predates C++17 implicitly
// inline constexpr static members — confirm before removing.
constexpr const Value transform::TransformState::kTopLevelValue;
22 
23 transform::TransformState::TransformState(Region &region, Operation *root)
24     : topLevel(root) {
25   auto result = mappings.try_emplace(&region);
26   assert(result.second && "the region scope is already present");
27   (void)result;
28 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
29   regionStack.push_back(&region);
30 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
31 }
32 
33 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
34 
35 ArrayRef<Operation *>
36 transform::TransformState::getPayloadOps(Value value) const {
37   const TransformOpMapping &operationMapping = getMapping(value).direct;
38   auto iter = operationMapping.find(value);
39   assert(iter != operationMapping.end() && "unknown handle");
40   return iter->getSecond();
41 }
42 
43 LogicalResult
44 transform::TransformState::setPayloadOps(Value value,
45                                          ArrayRef<Operation *> targets) {
46   assert(value != kTopLevelValue &&
47          "attempting to reset the transformation root");
48 
49   if (value.use_empty())
50     return success();
51 
52   // Setting new payload for the value without cleaning it first is a misuse of
53   // the API, assert here.
54   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
55   Mappings &mappings = getMapping(value);
56   bool inserted =
57       mappings.direct.insert({value, std::move(storedTargets)}).second;
58   assert(inserted && "value is already associated with another list");
59   (void)inserted;
60 
61   // Having multiple handles to the same operation is an error in the transform
62   // expressed using the dialect and may be constructed by valid API calls from
63   // valid IR. Emit an error here.
64   for (Operation *op : targets) {
65     auto insertionResult = mappings.reverse.insert({op, value});
66     if (!insertionResult.second) {
67       InFlightDiagnostic diag = op->emitError()
68                                 << "operation tracked by two handles";
69       diag.attachNote(value.getLoc()) << "handle";
70       diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
71       return diag;
72     }
73   }
74 
75   return success();
76 }
77 
78 void transform::TransformState::removePayloadOps(Value value) {
79   Mappings &mappings = getMapping(value);
80   for (Operation *op : mappings.direct[value])
81     mappings.reverse.erase(op);
82   mappings.direct.erase(value);
83 }
84 
85 void transform::TransformState::updatePayloadOps(
86     Value value, function_ref<Operation *(Operation *)> callback) {
87   auto it = getMapping(value).direct.find(value);
88   assert(it != getMapping(value).direct.end() && "unknown handle");
89   SmallVector<Operation *> &association = it->getSecond();
90   SmallVector<Operation *> updated;
91   updated.reserve(association.size());
92 
93   for (Operation *op : association)
94     if (Operation *updatedOp = callback(op))
95       updated.push_back(updatedOp);
96 
97   std::swap(association, updated);
98 }
99 
100 LogicalResult
101 transform::TransformState::applyTransform(TransformOpInterface transform) {
102   transform::TransformResults results(transform->getNumResults());
103   if (failed(transform.apply(results, *this)))
104     return failure();
105 
106   for (Value target : transform->getOperands())
107     removePayloadOps(target);
108 
109   for (auto &en : llvm::enumerate(transform->getResults())) {
110     assert(en.value().getDefiningOp() == transform.getOperation() &&
111            "payload IR association for a value other than the result of the "
112            "current transform op");
113     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
114       return failure();
115   }
116 
117   return success();
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // TransformResults
122 //===----------------------------------------------------------------------===//
123 
124 transform::TransformResults::TransformResults(unsigned numSegments) {
125   segments.resize(numSegments,
126                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
127 }
128 
129 void transform::TransformResults::set(OpResult value,
130                                       ArrayRef<Operation *> ops) {
131   unsigned position = value.getResultNumber();
132   assert(position < segments.size() &&
133          "setting results for a non-existent handle");
134   assert(segments[position].data() == nullptr && "results already set");
135   unsigned start = operations.size();
136   llvm::append_range(operations, ops);
137   segments[position] = makeArrayRef(operations).drop_front(start);
138 }
139 
140 ArrayRef<Operation *>
141 transform::TransformResults::get(unsigned resultNumber) const {
142   assert(resultNumber < segments.size() &&
143          "querying results for a non-existent handle");
144   assert(segments[resultNumber].data() != nullptr && "querying unset results");
145   return segments[resultNumber];
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Generated interface implementation.
150 //===----------------------------------------------------------------------===//
151 
152 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
153