1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
10 #include "mlir/IR/Diagnostics.h"
11 #include "mlir/IR/Operation.h"
12 #include "llvm/ADT/SmallPtrSet.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // TransformState
18 //===----------------------------------------------------------------------===//
19 
20 constexpr const Value transform::TransformState::kTopLevelValue;
21 
22 transform::TransformState::TransformState(Operation *root) {
23   operationMapping[kTopLevelValue].push_back(root);
24 }
25 
26 Operation *transform::TransformState::getTopLevel() const {
27   return operationMapping.lookup(kTopLevelValue).front();
28 }
29 
30 ArrayRef<Operation *>
31 transform::TransformState::getPayloadOps(Value value) const {
32   auto iter = operationMapping.find(value);
33   assert(iter != operationMapping.end() && "unknown handle");
34   return iter->getSecond();
35 }
36 
37 LogicalResult
38 transform::TransformState::setPayloadOps(Value value,
39                                          ArrayRef<Operation *> targets) {
40   assert(value != kTopLevelValue &&
41          "attempting to reset the transformation root");
42 
43   if (value.use_empty())
44     return success();
45 
46   // Setting new payload for the value without cleaning it first is a misuse of
47   // the API, assert here.
48   SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
49   bool inserted =
50       operationMapping.insert({value, std::move(storedTargets)}).second;
51   assert(inserted && "value is already associated with another list");
52   (void)inserted;
53 
54   // Having multiple handles to the same operation is an error in the transform
55   // expressed using the dialect and may be constructed by valid API calls from
56   // valid IR. Emit an error here.
57   for (Operation *op : targets) {
58     auto insertionResult = reverseMapping.insert({op, value});
59     if (!insertionResult.second) {
60       InFlightDiagnostic diag = op->emitError()
61                                 << "operation tracked by two handles";
62       diag.attachNote(value.getLoc()) << "handle";
63       diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
64       return diag;
65     }
66   }
67 
68   return success();
69 }
70 
71 void transform::TransformState::removePayloadOps(Value value) {
72   for (Operation *op : operationMapping[value])
73     reverseMapping.erase(op);
74   operationMapping.erase(value);
75 }
76 
77 void transform::TransformState::updatePayloadOps(
78     Value value, function_ref<Operation *(Operation *)> callback) {
79   auto it = operationMapping.find(value);
80   assert(it != operationMapping.end() && "unknown handle");
81   SmallVector<Operation *> &association = it->getSecond();
82   SmallVector<Operation *> updated;
83   updated.reserve(association.size());
84 
85   for (Operation *op : association)
86     if (Operation *updatedOp = callback(op))
87       updated.push_back(updatedOp);
88 
89   std::swap(association, updated);
90 }
91 
92 LogicalResult
93 transform::TransformState::applyTransform(TransformOpInterface transform) {
94   transform::TransformResults results(transform->getNumResults());
95   if (failed(transform.apply(results, *this)))
96     return failure();
97 
98   for (Value target : transform->getOperands())
99     removePayloadOps(target);
100 
101   for (auto &en : llvm::enumerate(transform->getResults()))
102     if (failed(setPayloadOps(en.value(), results.get(en.index()))))
103       return failure();
104 
105   return success();
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // TransformResults
110 //===----------------------------------------------------------------------===//
111 
112 transform::TransformResults::TransformResults(unsigned numSegments) {
113   segments.resize(numSegments,
114                   ArrayRef<Operation *>(nullptr, static_cast<size_t>(0)));
115 }
116 
117 void transform::TransformResults::set(OpResult value,
118                                       ArrayRef<Operation *> ops) {
119   unsigned position = value.getResultNumber();
120   assert(position < segments.size() &&
121          "setting results for a non-existent handle");
122   assert(segments[position].data() == nullptr && "results already set");
123   unsigned start = operations.size();
124   llvm::append_range(operations, ops);
125   segments[position] = makeArrayRef(operations).drop_front(start);
126 }
127 
128 ArrayRef<Operation *>
129 transform::TransformResults::get(unsigned resultNumber) const {
130   assert(resultNumber < segments.size() &&
131          "querying results for a non-existent handle");
132   assert(segments[resultNumber].data() != nullptr && "querying unset results");
133   return segments[resultNumber];
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Generated interface implementation.
138 //===----------------------------------------------------------------------===//
139 
140 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc"
141