1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionInterfaces.h"
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/FunctionOpInterfaces.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22 
23 static bool isEmptyAttrDict(Attribute attr) {
24   return attr.cast<DictionaryAttr>().empty();
25 }
26 
27 DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op,
28                                                              unsigned index) {
29   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
30   DictionaryAttr argAttrs =
31       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
32   return argAttrs;
33 }
34 
35 DictionaryAttr
36 mlir::function_interface_impl::getResultAttrDict(Operation *op,
37                                                  unsigned index) {
38   ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
39   DictionaryAttr resAttrs =
40       attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
41   return resAttrs;
42 }
43 
44 void mlir::function_interface_impl::detail::setArgResAttrDict(
45     Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
46     DictionaryAttr attrs) {
47   ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
48   if (!allAttrs) {
49     if (attrs.empty())
50       return;
51 
52     // If this attribute is not empty, we need to create a new attribute array.
53     SmallVector<Attribute, 8> newAttrs(numTotalIndices,
54                                        DictionaryAttr::get(op->getContext()));
55     newAttrs[index] = attrs;
56     op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
57     return;
58   }
59   // Check to see if the attribute is different from what we already have.
60   if (allAttrs[index] == attrs)
61     return;
62 
63   // If it is, check to see if the attribute array would now contain only empty
64   // dictionaries.
65   ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
66   if (attrs.empty() &&
67       llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
68       llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
69     op->removeAttr(attrName);
70     return;
71   }
72 
73   // Otherwise, create a new attribute array with the updated dictionary.
74   SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
75   newAttrs[index] = attrs;
76   op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
77 }
78 
79 /// Set all of the argument or result attribute dictionaries for a function.
80 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
81                                   ArrayRef<Attribute> attrs) {
82   if (llvm::all_of(attrs, isEmptyAttrDict))
83     op->removeAttr(attrName);
84   else
85     op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
86 }
87 
88 void mlir::function_interface_impl::setAllArgAttrDicts(
89     Operation *op, ArrayRef<DictionaryAttr> attrs) {
90   setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
91 }
92 void mlir::function_interface_impl::setAllArgAttrDicts(
93     Operation *op, ArrayRef<Attribute> attrs) {
94   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
95     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
96   });
97   setAllArgResAttrDicts(op, getArgDictAttrName(),
98                         llvm::to_vector<8>(wrappedAttrs));
99 }
100 
101 void mlir::function_interface_impl::setAllResultAttrDicts(
102     Operation *op, ArrayRef<DictionaryAttr> attrs) {
103   setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
104 }
105 void mlir::function_interface_impl::setAllResultAttrDicts(
106     Operation *op, ArrayRef<Attribute> attrs) {
107   auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
108     return !attr ? DictionaryAttr::get(op->getContext()) : attr;
109   });
110   setAllArgResAttrDicts(op, getResultDictAttrName(),
111                         llvm::to_vector<8>(wrappedAttrs));
112 }
113 
114 void mlir::function_interface_impl::insertFunctionArguments(
115     Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
116     ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
117     unsigned originalNumArgs, Type newType) {
118   assert(argIndices.size() == argTypes.size());
119   assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
120   assert(argIndices.size() == argLocs.size());
121   if (argIndices.empty())
122     return;
123 
124   // There are 3 things that need to be updated:
125   // - Function type.
126   // - Arg attrs.
127   // - Block arguments of entry block.
128   Block &entry = op->getRegion(0).front();
129 
130   // Update the argument attributes of the function.
131   auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
132   if (oldArgAttrs || !argAttrs.empty()) {
133     SmallVector<DictionaryAttr, 4> newArgAttrs;
134     newArgAttrs.reserve(originalNumArgs + argIndices.size());
135     unsigned oldIdx = 0;
136     auto migrate = [&](unsigned untilIdx) {
137       if (!oldArgAttrs) {
138         newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
139       } else {
140         auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
141         newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
142                            oldArgAttrRange.begin() + untilIdx);
143       }
144       oldIdx = untilIdx;
145     };
146     for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
147       migrate(argIndices[i]);
148       newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
149     }
150     migrate(originalNumArgs);
151     setAllArgAttrDicts(op, newArgAttrs);
152   }
153 
154   // Update the function type and any entry block arguments.
155   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
156   for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
157     entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
158 }
159 
160 void mlir::function_interface_impl::insertFunctionResults(
161     Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
162     ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
163     Type newType) {
164   assert(resultIndices.size() == resultTypes.size());
165   assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
166   if (resultIndices.empty())
167     return;
168 
169   // There are 2 things that need to be updated:
170   // - Function type.
171   // - Result attrs.
172 
173   // Update the result attributes of the function.
174   auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
175   if (oldResultAttrs || !resultAttrs.empty()) {
176     SmallVector<DictionaryAttr, 4> newResultAttrs;
177     newResultAttrs.reserve(originalNumResults + resultIndices.size());
178     unsigned oldIdx = 0;
179     auto migrate = [&](unsigned untilIdx) {
180       if (!oldResultAttrs) {
181         newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
182       } else {
183         auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
184         newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
185                               oldResultAttrsRange.begin() + untilIdx);
186       }
187       oldIdx = untilIdx;
188     };
189     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
190       migrate(resultIndices[i]);
191       newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
192                                                    : resultAttrs[i]);
193     }
194     migrate(originalNumResults);
195     setAllResultAttrDicts(op, newResultAttrs);
196   }
197 
198   // Update the function type.
199   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
200 }
201 
202 void mlir::function_interface_impl::eraseFunctionArguments(
203     Operation *op, const BitVector &argIndices, Type newType) {
204   // There are 3 things that need to be updated:
205   // - Function type.
206   // - Arg attrs.
207   // - Block arguments of entry block.
208   Block &entry = op->getRegion(0).front();
209 
210   // Update the argument attributes of the function.
211   if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
212     SmallVector<DictionaryAttr, 4> newArgAttrs;
213     newArgAttrs.reserve(argAttrs.size());
214     for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
215       if (!argIndices[i])
216         newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
217     setAllArgAttrDicts(op, newArgAttrs);
218   }
219 
220   // Update the function type and any entry block arguments.
221   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
222   entry.eraseArguments(argIndices);
223 }
224 
225 void mlir::function_interface_impl::eraseFunctionResults(
226     Operation *op, const BitVector &resultIndices, Type newType) {
227   // There are 2 things that need to be updated:
228   // - Function type.
229   // - Result attrs.
230 
231   // Update the result attributes of the function.
232   if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
233     SmallVector<DictionaryAttr, 4> newResultAttrs;
234     newResultAttrs.reserve(resAttrs.size());
235     for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
236       if (!resultIndices[i])
237         newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
238     setAllResultAttrDicts(op, newResultAttrs);
239   }
240 
241   // Update the function type.
242   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
243 }
244 
245 TypeRange mlir::function_interface_impl::insertTypesInto(
246     TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
247     SmallVectorImpl<Type> &storage) {
248   assert(indices.size() == newTypes.size() &&
249          "mismatch between indice and type count");
250   if (indices.empty())
251     return oldTypes;
252 
253   auto fromIt = oldTypes.begin();
254   for (auto it : llvm::zip(indices, newTypes)) {
255     const auto toIt = oldTypes.begin() + std::get<0>(it);
256     storage.append(fromIt, toIt);
257     storage.push_back(std::get<1>(it));
258     fromIt = toIt;
259   }
260   storage.append(fromIt, oldTypes.end());
261   return storage;
262 }
263 
264 TypeRange
265 mlir::function_interface_impl::filterTypesOut(TypeRange types,
266                                               const BitVector &indices,
267                                               SmallVectorImpl<Type> &storage) {
268   if (indices.none())
269     return types;
270 
271   for (unsigned i = 0, e = types.size(); i < e; ++i)
272     if (!indices[i])
273       storage.emplace_back(types[i]);
274   return storage;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Function type signature.
279 //===----------------------------------------------------------------------===//
280 
281 void mlir::function_interface_impl::setFunctionType(Operation *op,
282                                                     Type newType) {
283   FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
284   unsigned oldNumArgs = funcOp.getNumArguments();
285   unsigned oldNumResults = funcOp.getNumResults();
286   op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
287   unsigned newNumArgs = funcOp.getNumArguments();
288   unsigned newNumResults = funcOp.getNumResults();
289 
290   // Functor used to update the argument and result attributes of the function.
291   auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
292                           unsigned newCount, auto setAttrFn) {
293     if (oldCount == newCount)
294       return;
295     // The new type has no arguments/results, just drop the attribute.
296     if (newCount == 0) {
297       op->removeAttr(attrName);
298       return;
299     }
300     ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
301     if (!attrs)
302       return;
303 
304     // The new type has less arguments/results, take the first N attributes.
305     if (newCount < oldCount)
306       return setAttrFn(op, attrs.getValue().take_front(newCount));
307 
308     // Otherwise, the new type has more arguments/results. Initialize the new
309     // arguments/results with empty attributes.
310     SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
311     newAttrs.resize(newCount);
312     setAttrFn(op, newAttrs);
313   };
314 
315   // Update the argument and result attributes.
316   updateAttrFn(
317       getArgDictAttrName(), oldNumArgs, newNumArgs,
318       [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
319   updateAttrFn(
320       getResultDictAttrName(), oldNumResults, newNumResults,
321       [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
322 }
323