1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/FunctionInterfaces.h" 10 #include "mlir/Support/LLVM.h" 11 #include "llvm/ADT/BitVector.h" 12 13 using namespace mlir; 14 15 /// Helper to call a callback once on each index in the range 16 /// [0, `totalIndices`), *except* for the indices given in `indices`. 17 /// `indices` is allowed to have duplicates and can be in any order. 18 inline static void iterateIndicesExcept(unsigned totalIndices, 19 ArrayRef<unsigned> indices, 20 function_ref<void(unsigned)> callback) { 21 llvm::BitVector skipIndices(totalIndices); 22 for (unsigned i : indices) 23 skipIndices.set(i); 24 25 for (unsigned i = 0; i < totalIndices; ++i) 26 if (!skipIndices.test(i)) 27 callback(i); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // Tablegen Interface Definitions 32 //===----------------------------------------------------------------------===// 33 34 #include "mlir/IR/FunctionOpInterfaces.cpp.inc" 35 36 //===----------------------------------------------------------------------===// 37 // Function Arguments and Results. 38 //===----------------------------------------------------------------------===// 39 40 static bool isEmptyAttrDict(Attribute attr) { 41 return attr.cast<DictionaryAttr>().empty(); 42 } 43 44 DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, 45 unsigned index) { 46 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 47 DictionaryAttr argAttrs = 48 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr(); 49 return argAttrs; 50 } 51 52 DictionaryAttr 53 mlir::function_interface_impl::getResultAttrDict(Operation *op, 54 unsigned index) { 55 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 56 DictionaryAttr resAttrs = 57 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr(); 58 return resAttrs; 59 } 60 61 void mlir::function_interface_impl::detail::setArgResAttrDict( 62 Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, 63 DictionaryAttr attrs) { 64 ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName); 65 if (!allAttrs) { 66 if (attrs.empty()) 67 return; 68 69 // If this attribute is not empty, we need to create a new attribute array. 70 SmallVector<Attribute, 8> newAttrs(numTotalIndices, 71 DictionaryAttr::get(op->getContext())); 72 newAttrs[index] = attrs; 73 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); 74 return; 75 } 76 // Check to see if the attribute is different from what we already have. 77 if (allAttrs[index] == attrs) 78 return; 79 80 // If it is, check to see if the attribute array would now contain only empty 81 // dictionaries. 82 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); 83 if (attrs.empty() && 84 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && 85 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { 86 op->removeAttr(attrName); 87 return; 88 } 89 90 // Otherwise, create a new attribute array with the updated dictionary. 91 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end()); 92 newAttrs[index] = attrs; 93 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); 94 } 95 96 /// Set all of the argument or result attribute dictionaries for a function. 97 static void setAllArgResAttrDicts(Operation *op, StringRef attrName, 98 ArrayRef<Attribute> attrs) { 99 if (llvm::all_of(attrs, isEmptyAttrDict)) 100 op->removeAttr(attrName); 101 else 102 op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); 103 } 104 105 void mlir::function_interface_impl::setAllArgAttrDicts( 106 Operation *op, ArrayRef<DictionaryAttr> attrs) { 107 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 108 } 109 void mlir::function_interface_impl::setAllArgAttrDicts( 110 Operation *op, ArrayRef<Attribute> attrs) { 111 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 112 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 113 }); 114 setAllArgResAttrDicts(op, getArgDictAttrName(), 115 llvm::to_vector<8>(wrappedAttrs)); 116 } 117 118 void mlir::function_interface_impl::setAllResultAttrDicts( 119 Operation *op, ArrayRef<DictionaryAttr> attrs) { 120 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 121 } 122 void mlir::function_interface_impl::setAllResultAttrDicts( 123 Operation *op, ArrayRef<Attribute> attrs) { 124 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 125 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 126 }); 127 setAllArgResAttrDicts(op, getResultDictAttrName(), 128 llvm::to_vector<8>(wrappedAttrs)); 129 } 130 131 void mlir::function_interface_impl::insertFunctionArguments( 132 Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes, 133 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Optional<Location>> argLocs, 134 unsigned originalNumArgs, Type newType) { 135 assert(argIndices.size() == argTypes.size()); 136 assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); 137 assert(argIndices.size() == argLocs.size() || argLocs.empty()); 138 if (argIndices.empty()) 139 return; 140 141 // There are 3 things that need to be updated: 142 // - Function type. 143 // - Arg attrs. 144 // - Block arguments of entry block. 145 Block &entry = op->getRegion(0).front(); 146 147 // Update the argument attributes of the function. 148 auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 149 if (oldArgAttrs || !argAttrs.empty()) { 150 SmallVector<DictionaryAttr, 4> newArgAttrs; 151 newArgAttrs.reserve(originalNumArgs + argIndices.size()); 152 unsigned oldIdx = 0; 153 auto migrate = [&](unsigned untilIdx) { 154 if (!oldArgAttrs) { 155 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); 156 } else { 157 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>(); 158 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, 159 oldArgAttrRange.begin() + untilIdx); 160 } 161 oldIdx = untilIdx; 162 }; 163 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { 164 migrate(argIndices[i]); 165 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); 166 } 167 migrate(originalNumArgs); 168 setAllArgAttrDicts(op, newArgAttrs); 169 } 170 171 // Update the function type and any entry block arguments. 172 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 173 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) 174 entry.insertArgument(argIndices[i] + i, argTypes[i], 175 argLocs.empty() ? Optional<Location>{} : argLocs[i]); 176 } 177 178 void mlir::function_interface_impl::insertFunctionResults( 179 Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes, 180 ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults, 181 Type newType) { 182 assert(resultIndices.size() == resultTypes.size()); 183 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); 184 if (resultIndices.empty()) 185 return; 186 187 // There are 2 things that need to be updated: 188 // - Function type. 189 // - Result attrs. 190 191 // Update the result attributes of the function. 192 auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 193 if (oldResultAttrs || !resultAttrs.empty()) { 194 SmallVector<DictionaryAttr, 4> newResultAttrs; 195 newResultAttrs.reserve(originalNumResults + resultIndices.size()); 196 unsigned oldIdx = 0; 197 auto migrate = [&](unsigned untilIdx) { 198 if (!oldResultAttrs) { 199 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); 200 } else { 201 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>(); 202 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, 203 oldResultAttrsRange.begin() + untilIdx); 204 } 205 oldIdx = untilIdx; 206 }; 207 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { 208 migrate(resultIndices[i]); 209 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} 210 : resultAttrs[i]); 211 } 212 migrate(originalNumResults); 213 setAllResultAttrDicts(op, newResultAttrs); 214 } 215 216 // Update the function type. 217 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 218 } 219 220 void mlir::function_interface_impl::eraseFunctionArguments( 221 Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs, 222 Type newType) { 223 // There are 3 things that need to be updated: 224 // - Function type. 225 // - Arg attrs. 226 // - Block arguments of entry block. 227 Block &entry = op->getRegion(0).front(); 228 229 // Update the argument attributes of the function. 230 if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) { 231 SmallVector<DictionaryAttr, 4> newArgAttrs; 232 newArgAttrs.reserve(argAttrs.size()); 233 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 234 newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>()); 235 }); 236 setAllArgAttrDicts(op, newArgAttrs); 237 } 238 239 // Update the function type and any entry block arguments. 240 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 241 entry.eraseArguments(argIndices); 242 } 243 244 void mlir::function_interface_impl::eraseFunctionResults( 245 Operation *op, ArrayRef<unsigned> resultIndices, 246 unsigned originalNumResults, Type newType) { 247 // There are 2 things that need to be updated: 248 // - Function type. 249 // - Result attrs. 250 251 // Update the result attributes of the function. 252 if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) { 253 SmallVector<DictionaryAttr, 4> newResultAttrs; 254 newResultAttrs.reserve(resAttrs.size()); 255 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 256 newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>()); 257 }); 258 setAllResultAttrDicts(op, newResultAttrs); 259 } 260 261 // Update the function type. 262 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 263 } 264 265 TypeRange mlir::function_interface_impl::insertTypesInto( 266 TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes, 267 SmallVectorImpl<Type> &storage) { 268 assert(indices.size() == newTypes.size() && 269 "mismatch between indice and type count"); 270 if (indices.empty()) 271 return oldTypes; 272 273 auto fromIt = oldTypes.begin(); 274 for (auto it : llvm::zip(indices, newTypes)) { 275 const auto toIt = oldTypes.begin() + std::get<0>(it); 276 storage.append(fromIt, toIt); 277 storage.push_back(std::get<1>(it)); 278 fromIt = toIt; 279 } 280 storage.append(fromIt, oldTypes.end()); 281 return storage; 282 } 283 284 TypeRange 285 mlir::function_interface_impl::filterTypesOut(TypeRange types, 286 ArrayRef<unsigned> indices, 287 SmallVectorImpl<Type> &storage) { 288 if (indices.empty()) 289 return types; 290 iterateIndicesExcept(types.size(), indices, 291 [&](unsigned i) { storage.emplace_back(types[i]); }); 292 return storage; 293 } 294 295 //===----------------------------------------------------------------------===// 296 // Function type signature. 297 //===----------------------------------------------------------------------===// 298 299 void mlir::function_interface_impl::setFunctionType(Operation *op, 300 Type newType) { 301 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 302 unsigned oldNumArgs = funcOp.getNumArguments(); 303 unsigned oldNumResults = funcOp.getNumResults(); 304 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 305 unsigned newNumArgs = funcOp.getNumArguments(); 306 unsigned newNumResults = funcOp.getNumResults(); 307 308 // Functor used to update the argument and result attributes of the function. 309 auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, 310 unsigned newCount, auto setAttrFn) { 311 if (oldCount == newCount) 312 return; 313 // The new type has no arguments/results, just drop the attribute. 314 if (newCount == 0) { 315 op->removeAttr(attrName); 316 return; 317 } 318 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName); 319 if (!attrs) 320 return; 321 322 // The new type has less arguments/results, take the first N attributes. 323 if (newCount < oldCount) 324 return setAttrFn(op, attrs.getValue().take_front(newCount)); 325 326 // Otherwise, the new type has more arguments/results. Initialize the new 327 // arguments/results with empty attributes. 328 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end()); 329 newAttrs.resize(newCount); 330 setAttrFn(op, newAttrs); 331 }; 332 333 // Update the argument and result attributes. 334 updateAttrFn( 335 getArgDictAttrName(), oldNumArgs, newNumArgs, 336 [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); 337 updateAttrFn( 338 getResultDictAttrName(), oldNumResults, newNumResults, 339 [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); 340 } 341