1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionInterfaces.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/BitVector.h"
12 
13 using namespace mlir;
14 
15 /// Helper to call a callback once on each index in the range
16 /// [0, `totalIndices`), *except* for the indices given in `indices`.
17 /// `indices` is allowed to have duplicates and can be in any order.
18 inline static void iterateIndicesExcept(unsigned totalIndices,
19                                         ArrayRef<unsigned> indices,
20                                         function_ref<void(unsigned)> callback) {
21   llvm::BitVector skipIndices(totalIndices);
22   for (unsigned i : indices)
23     skipIndices.set(i);
24 
25   for (unsigned i = 0; i < totalIndices; ++i)
26     if (!skipIndices.test(i))
27       callback(i);
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // Tablegen Interface Definitions
32 //===----------------------------------------------------------------------===//
33 
34 #include "mlir/IR/FunctionOpInterfaces.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // Function Arguments and Results.
38 //===----------------------------------------------------------------------===//
39 
40 static bool isEmptyAttrDict(Attribute attr) {
41   return attr.cast<DictionaryAttr>().empty();
42 }
43 
44 DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
45                                                              unsigned index) {
46   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
47   DictionaryAttr argAttrs =
48       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
49   return argAttrs;
50 }
51 
52 DictionaryAttr
53 mlir::function_interface_impl::getResultAttrDict(Operation *op,
54                                                  unsigned index) {
55   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
56   DictionaryAttr resAttrs =
57       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
58   return resAttrs;
59 }
60 
61 void mlir::function_interface_impl::detail::setArgResAttrDict(
62     Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
63     DictionaryAttr attrs) {
64   ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
65   if (!allAttrs) {
66     if (attrs.empty())
67       return;
68 
69     // If this attribute is not empty, we need to create a new attribute array.
70     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
71                                        DictionaryAttr::get(op->getContext()));
72     newAttrs[index] = attrs;
73     op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
74     return;
75   }
76   // Check to see if the attribute is different from what we already have.
77   if (allAttrs[index] == attrs)
78     return;
79 
80   // If it is, check to see if the attribute array would now contain only empty
81   // dictionaries.
82   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
83   if (attrs.empty() &&
84       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
85       llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
86     op->removeAttr(attrName);
87     return;
88   }
89 
90   // Otherwise, create a new attribute array with the updated dictionary.
91   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
92   newAttrs[index] = attrs;
93   op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
94 }
95 
96 /// Set all of the argument or result attribute dictionaries for a function.
97 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
98                                   ArrayRef<Attribute> attrs) {
99   if (llvm::all_of(attrs, isEmptyAttrDict))
100     op->removeAttr(attrName);
101   else
102     op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
103 }
104 
105 void mlir::function_interface_impl::setAllArgAttrDicts(
106     Operation *op, ArrayRef<DictionaryAttr> attrs) {
107   setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
108 }
109 void mlir::function_interface_impl::setAllArgAttrDicts(
110     Operation *op, ArrayRef<Attribute> attrs) {
111   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
112     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
113   });
114   setAllArgResAttrDicts(op, getArgDictAttrName(),
115                         llvm::to_vector<8>(wrappedAttrs));
116 }
117 
118 void mlir::function_interface_impl::setAllResultAttrDicts(
119     Operation *op, ArrayRef<DictionaryAttr> attrs) {
120   setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
121 }
122 void mlir::function_interface_impl::setAllResultAttrDicts(
123     Operation *op, ArrayRef<Attribute> attrs) {
124   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
125     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
126   });
127   setAllArgResAttrDicts(op, getResultDictAttrName(),
128                         llvm::to_vector<8>(wrappedAttrs));
129 }
130 
131 void mlir::function_interface_impl::insertFunctionArguments(
132     Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
133     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs,
134     unsigned originalNumArgs, Type newType) {
135   assert(argIndices.size() == argTypes.size());
136   assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
137   assert(argIndices.size() == argLocs.size() || argLocs.empty());
138   if (argIndices.empty())
139     return;
140 
141   // There are 3 things that need to be updated:
142   // - Function type.
143   // - Arg attrs.
144   // - Block arguments of entry block.
145   Block &entry = op->getRegion(0).front();
146 
147   // Update the argument attributes of the function.
148   auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
149   if (oldArgAttrs || !argAttrs.empty()) {
150     SmallVector<DictionaryAttr, 4> newArgAttrs;
151     newArgAttrs.reserve(originalNumArgs + argIndices.size());
152     unsigned oldIdx = 0;
153     auto migrate = [&](unsigned untilIdx) {
154       if (!oldArgAttrs) {
155         newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
156       } else {
157         auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
158         newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
159                            oldArgAttrRange.begin() + untilIdx);
160       }
161       oldIdx = untilIdx;
162     };
163     for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
164       migrate(argIndices[i]);
165       newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
166     }
167     migrate(originalNumArgs);
168     setAllArgAttrDicts(op, newArgAttrs);
169   }
170 
171   // Update the function type and any entry block arguments.
172   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
173   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
174     entry.insertArgument(argIndices[i] + i, argTypes[i],
175                          argLocs.empty() ? Optional<Location>{} : argLocs[i]);
176 }
177 
178 void mlir::function_interface_impl::insertFunctionResults(
179     Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
180     ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
181     Type newType) {
182   assert(resultIndices.size() == resultTypes.size());
183   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
184   if (resultIndices.empty())
185     return;
186 
187   // There are 2 things that need to be updated:
188   // - Function type.
189   // - Result attrs.
190 
191   // Update the result attributes of the function.
192   auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
193   if (oldResultAttrs || !resultAttrs.empty()) {
194     SmallVector<DictionaryAttr, 4> newResultAttrs;
195     newResultAttrs.reserve(originalNumResults + resultIndices.size());
196     unsigned oldIdx = 0;
197     auto migrate = [&](unsigned untilIdx) {
198       if (!oldResultAttrs) {
199         newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
200       } else {
201         auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
202         newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
203                               oldResultAttrsRange.begin() + untilIdx);
204       }
205       oldIdx = untilIdx;
206     };
207     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
208       migrate(resultIndices[i]);
209       newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
210                                                    : resultAttrs[i]);
211     }
212     migrate(originalNumResults);
213     setAllResultAttrDicts(op, newResultAttrs);
214   }
215 
216   // Update the function type.
217   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
218 }
219 
220 void mlir::function_interface_impl::eraseFunctionArguments(
221     Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
222     Type newType) {
223   // There are 3 things that need to be updated:
224   // - Function type.
225   // - Arg attrs.
226   // - Block arguments of entry block.
227   Block &entry = op->getRegion(0).front();
228 
229   // Update the argument attributes of the function.
230   if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
231     SmallVector<DictionaryAttr, 4> newArgAttrs;
232     newArgAttrs.reserve(argAttrs.size());
233     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
234       newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
235     });
236     setAllArgAttrDicts(op, newArgAttrs);
237   }
238 
239   // Update the function type and any entry block arguments.
240   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
241   entry.eraseArguments(argIndices);
242 }
243 
244 void mlir::function_interface_impl::eraseFunctionResults(
245     Operation *op, ArrayRef<unsigned> resultIndices,
246     unsigned originalNumResults, Type newType) {
247   // There are 2 things that need to be updated:
248   // - Function type.
249   // - Result attrs.
250 
251   // Update the result attributes of the function.
252   if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
253     SmallVector<DictionaryAttr, 4> newResultAttrs;
254     newResultAttrs.reserve(resAttrs.size());
255     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
256       newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
257     });
258     setAllResultAttrDicts(op, newResultAttrs);
259   }
260 
261   // Update the function type.
262   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
263 }
264 
265 TypeRange mlir::function_interface_impl::insertTypesInto(
266     TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
267     SmallVectorImpl<Type> &storage) {
268   assert(indices.size() == newTypes.size() &&
269          "mismatch between indice and type count");
270   if (indices.empty())
271     return oldTypes;
272 
273   auto fromIt = oldTypes.begin();
274   for (auto it : llvm::zip(indices, newTypes)) {
275     const auto toIt = oldTypes.begin() + std::get<0>(it);
276     storage.append(fromIt, toIt);
277     storage.push_back(std::get<1>(it));
278     fromIt = toIt;
279   }
280   storage.append(fromIt, oldTypes.end());
281   return storage;
282 }
283 
284 TypeRange
285 mlir::function_interface_impl::filterTypesOut(TypeRange types,
286                                               ArrayRef<unsigned> indices,
287                                               SmallVectorImpl<Type> &storage) {
288   if (indices.empty())
289     return types;
290   iterateIndicesExcept(types.size(), indices,
291                        [&](unsigned i) { storage.emplace_back(types[i]); });
292   return storage;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Function type signature.
297 //===----------------------------------------------------------------------===//
298 
299 void mlir::function_interface_impl::setFunctionType(Operation *op,
300                                                     Type newType) {
301   FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
302   unsigned oldNumArgs = funcOp.getNumArguments();
303   unsigned oldNumResults = funcOp.getNumResults();
304   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
305   unsigned newNumArgs = funcOp.getNumArguments();
306   unsigned newNumResults = funcOp.getNumResults();
307 
308   // Functor used to update the argument and result attributes of the function.
309   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
310                           unsigned newCount, auto setAttrFn) {
311     if (oldCount == newCount)
312       return;
313     // The new type has no arguments/results, just drop the attribute.
314     if (newCount == 0) {
315       op->removeAttr(attrName);
316       return;
317     }
318     ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
319     if (!attrs)
320       return;
321 
322     // The new type has less arguments/results, take the first N attributes.
323     if (newCount < oldCount)
324       return setAttrFn(op, attrs.getValue().take_front(newCount));
325 
326     // Otherwise, the new type has more arguments/results. Initialize the new
327     // arguments/results with empty attributes.
328     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
329     newAttrs.resize(newCount);
330     setAttrFn(op, newAttrs);
331   };
332 
333   // Update the argument and result attributes.
334   updateAttrFn(
335       getArgDictAttrName(), oldNumArgs, newNumArgs,
336       [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
337   updateAttrFn(
338       getResultDictAttrName(), oldNumResults, newNumResults,
339       [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
340 }
341