//===- FunctionInterfaces.cpp - Utilities for function-like ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/FunctionInterfaces.h"
10
11 using namespace mlir;
12
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/IR/FunctionOpInterfaces.cpp.inc"
18
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22
isEmptyAttrDict(Attribute attr)23 static bool isEmptyAttrDict(Attribute attr) {
24 return attr.cast<DictionaryAttr>().empty();
25 }
26
getArgAttrDict(Operation * op,unsigned index)27 DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
28 unsigned index) {
29 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
30 DictionaryAttr argAttrs =
31 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
32 return argAttrs;
33 }
34
35 DictionaryAttr
getResultAttrDict(Operation * op,unsigned index)36 mlir::function_interface_impl::getResultAttrDict(Operation *op,
37 unsigned index) {
38 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
39 DictionaryAttr resAttrs =
40 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
41 return resAttrs;
42 }
43
setArgResAttrDict(Operation * op,StringRef attrName,unsigned numTotalIndices,unsigned index,DictionaryAttr attrs)44 void mlir::function_interface_impl::detail::setArgResAttrDict(
45 Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
46 DictionaryAttr attrs) {
47 ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
48 if (!allAttrs) {
49 if (attrs.empty())
50 return;
51
52 // If this attribute is not empty, we need to create a new attribute array.
53 SmallVector<Attribute, 8> newAttrs(numTotalIndices,
54 DictionaryAttr::get(op->getContext()));
55 newAttrs[index] = attrs;
56 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
57 return;
58 }
59 // Check to see if the attribute is different from what we already have.
60 if (allAttrs[index] == attrs)
61 return;
62
63 // If it is, check to see if the attribute array would now contain only empty
64 // dictionaries.
65 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
66 if (attrs.empty() &&
67 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
68 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
69 op->removeAttr(attrName);
70 return;
71 }
72
73 // Otherwise, create a new attribute array with the updated dictionary.
74 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
75 newAttrs[index] = attrs;
76 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
77 }
78
79 /// Set all of the argument or result attribute dictionaries for a function.
setAllArgResAttrDicts(Operation * op,StringRef attrName,ArrayRef<Attribute> attrs)80 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
81 ArrayRef<Attribute> attrs) {
82 if (llvm::all_of(attrs, isEmptyAttrDict))
83 op->removeAttr(attrName);
84 else
85 op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
86 }
87
setAllArgAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)88 void mlir::function_interface_impl::setAllArgAttrDicts(
89 Operation *op, ArrayRef<DictionaryAttr> attrs) {
90 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
91 }
setAllArgAttrDicts(Operation * op,ArrayRef<Attribute> attrs)92 void mlir::function_interface_impl::setAllArgAttrDicts(
93 Operation *op, ArrayRef<Attribute> attrs) {
94 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
95 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
96 });
97 setAllArgResAttrDicts(op, getArgDictAttrName(),
98 llvm::to_vector<8>(wrappedAttrs));
99 }
100
setAllResultAttrDicts(Operation * op,ArrayRef<DictionaryAttr> attrs)101 void mlir::function_interface_impl::setAllResultAttrDicts(
102 Operation *op, ArrayRef<DictionaryAttr> attrs) {
103 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
104 }
setAllResultAttrDicts(Operation * op,ArrayRef<Attribute> attrs)105 void mlir::function_interface_impl::setAllResultAttrDicts(
106 Operation *op, ArrayRef<Attribute> attrs) {
107 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
108 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
109 });
110 setAllArgResAttrDicts(op, getResultDictAttrName(),
111 llvm::to_vector<8>(wrappedAttrs));
112 }
113
insertFunctionArguments(Operation * op,ArrayRef<unsigned> argIndices,TypeRange argTypes,ArrayRef<DictionaryAttr> argAttrs,ArrayRef<Location> argLocs,unsigned originalNumArgs,Type newType)114 void mlir::function_interface_impl::insertFunctionArguments(
115 Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
116 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
117 unsigned originalNumArgs, Type newType) {
118 assert(argIndices.size() == argTypes.size());
119 assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
120 assert(argIndices.size() == argLocs.size());
121 if (argIndices.empty())
122 return;
123
124 // There are 3 things that need to be updated:
125 // - Function type.
126 // - Arg attrs.
127 // - Block arguments of entry block.
128 Block &entry = op->getRegion(0).front();
129
130 // Update the argument attributes of the function.
131 auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
132 if (oldArgAttrs || !argAttrs.empty()) {
133 SmallVector<DictionaryAttr, 4> newArgAttrs;
134 newArgAttrs.reserve(originalNumArgs + argIndices.size());
135 unsigned oldIdx = 0;
136 auto migrate = [&](unsigned untilIdx) {
137 if (!oldArgAttrs) {
138 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
139 } else {
140 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
141 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
142 oldArgAttrRange.begin() + untilIdx);
143 }
144 oldIdx = untilIdx;
145 };
146 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
147 migrate(argIndices[i]);
148 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
149 }
150 migrate(originalNumArgs);
151 setAllArgAttrDicts(op, newArgAttrs);
152 }
153
154 // Update the function type and any entry block arguments.
155 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
156 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
157 entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
158 }
159
insertFunctionResults(Operation * op,ArrayRef<unsigned> resultIndices,TypeRange resultTypes,ArrayRef<DictionaryAttr> resultAttrs,unsigned originalNumResults,Type newType)160 void mlir::function_interface_impl::insertFunctionResults(
161 Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
162 ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
163 Type newType) {
164 assert(resultIndices.size() == resultTypes.size());
165 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
166 if (resultIndices.empty())
167 return;
168
169 // There are 2 things that need to be updated:
170 // - Function type.
171 // - Result attrs.
172
173 // Update the result attributes of the function.
174 auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
175 if (oldResultAttrs || !resultAttrs.empty()) {
176 SmallVector<DictionaryAttr, 4> newResultAttrs;
177 newResultAttrs.reserve(originalNumResults + resultIndices.size());
178 unsigned oldIdx = 0;
179 auto migrate = [&](unsigned untilIdx) {
180 if (!oldResultAttrs) {
181 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
182 } else {
183 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
184 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
185 oldResultAttrsRange.begin() + untilIdx);
186 }
187 oldIdx = untilIdx;
188 };
189 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
190 migrate(resultIndices[i]);
191 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
192 : resultAttrs[i]);
193 }
194 migrate(originalNumResults);
195 setAllResultAttrDicts(op, newResultAttrs);
196 }
197
198 // Update the function type.
199 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
200 }
201
eraseFunctionArguments(Operation * op,const BitVector & argIndices,Type newType)202 void mlir::function_interface_impl::eraseFunctionArguments(
203 Operation *op, const BitVector &argIndices, Type newType) {
204 // There are 3 things that need to be updated:
205 // - Function type.
206 // - Arg attrs.
207 // - Block arguments of entry block.
208 Block &entry = op->getRegion(0).front();
209
210 // Update the argument attributes of the function.
211 if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
212 SmallVector<DictionaryAttr, 4> newArgAttrs;
213 newArgAttrs.reserve(argAttrs.size());
214 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
215 if (!argIndices[i])
216 newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
217 setAllArgAttrDicts(op, newArgAttrs);
218 }
219
220 // Update the function type and any entry block arguments.
221 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
222 entry.eraseArguments(argIndices);
223 }
224
eraseFunctionResults(Operation * op,const BitVector & resultIndices,Type newType)225 void mlir::function_interface_impl::eraseFunctionResults(
226 Operation *op, const BitVector &resultIndices, Type newType) {
227 // There are 2 things that need to be updated:
228 // - Function type.
229 // - Result attrs.
230
231 // Update the result attributes of the function.
232 if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
233 SmallVector<DictionaryAttr, 4> newResultAttrs;
234 newResultAttrs.reserve(resAttrs.size());
235 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
236 if (!resultIndices[i])
237 newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
238 setAllResultAttrDicts(op, newResultAttrs);
239 }
240
241 // Update the function type.
242 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
243 }
244
insertTypesInto(TypeRange oldTypes,ArrayRef<unsigned> indices,TypeRange newTypes,SmallVectorImpl<Type> & storage)245 TypeRange mlir::function_interface_impl::insertTypesInto(
246 TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
247 SmallVectorImpl<Type> &storage) {
248 assert(indices.size() == newTypes.size() &&
249 "mismatch between indice and type count");
250 if (indices.empty())
251 return oldTypes;
252
253 auto fromIt = oldTypes.begin();
254 for (auto it : llvm::zip(indices, newTypes)) {
255 const auto toIt = oldTypes.begin() + std::get<0>(it);
256 storage.append(fromIt, toIt);
257 storage.push_back(std::get<1>(it));
258 fromIt = toIt;
259 }
260 storage.append(fromIt, oldTypes.end());
261 return storage;
262 }
263
filterTypesOut(TypeRange types,const BitVector & indices,SmallVectorImpl<Type> & storage)264 TypeRange mlir::function_interface_impl::filterTypesOut(
265 TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
266 if (indices.none())
267 return types;
268
269 for (unsigned i = 0, e = types.size(); i < e; ++i)
270 if (!indices[i])
271 storage.emplace_back(types[i]);
272 return storage;
273 }
274
275 //===----------------------------------------------------------------------===//
276 // Function type signature.
277 //===----------------------------------------------------------------------===//
278
setFunctionType(Operation * op,Type newType)279 void mlir::function_interface_impl::setFunctionType(Operation *op,
280 Type newType) {
281 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
282 unsigned oldNumArgs = funcOp.getNumArguments();
283 unsigned oldNumResults = funcOp.getNumResults();
284 op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
285 unsigned newNumArgs = funcOp.getNumArguments();
286 unsigned newNumResults = funcOp.getNumResults();
287
288 // Functor used to update the argument and result attributes of the function.
289 auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
290 unsigned newCount, auto setAttrFn) {
291 if (oldCount == newCount)
292 return;
293 // The new type has no arguments/results, just drop the attribute.
294 if (newCount == 0) {
295 op->removeAttr(attrName);
296 return;
297 }
298 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
299 if (!attrs)
300 return;
301
302 // The new type has less arguments/results, take the first N attributes.
303 if (newCount < oldCount)
304 return setAttrFn(op, attrs.getValue().take_front(newCount));
305
306 // Otherwise, the new type has more arguments/results. Initialize the new
307 // arguments/results with empty attributes.
308 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
309 newAttrs.resize(newCount);
310 setAttrFn(op, newAttrs);
311 };
312
313 // Update the argument and result attributes.
314 updateAttrFn(
315 getArgDictAttrName(), oldNumArgs, newNumArgs,
316 [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
317 updateAttrFn(
318 getResultDictAttrName(), oldNumResults, newNumResults,
319 [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
320 }
321