1 //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/FunctionInterfaces.h" 10 11 using namespace mlir; 12 13 //===----------------------------------------------------------------------===// 14 // Tablegen Interface Definitions 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/IR/FunctionOpInterfaces.cpp.inc" 18 19 //===----------------------------------------------------------------------===// 20 // Function Arguments and Results. 21 //===----------------------------------------------------------------------===// 22 23 static bool isEmptyAttrDict(Attribute attr) { 24 return attr.cast<DictionaryAttr>().empty(); 25 } 26 27 DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, 28 unsigned index) { 29 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 30 DictionaryAttr argAttrs = 31 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr(); 32 return argAttrs; 33 } 34 35 DictionaryAttr 36 mlir::function_interface_impl::getResultAttrDict(Operation *op, 37 unsigned index) { 38 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 39 DictionaryAttr resAttrs = 40 attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr(); 41 return resAttrs; 42 } 43 44 void mlir::function_interface_impl::detail::setArgResAttrDict( 45 Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, 46 DictionaryAttr attrs) { 47 ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName); 48 if (!allAttrs) { 49 if (attrs.empty()) 50 return; 51 52 // If this attribute is not empty, we need to create a new attribute array. 53 SmallVector<Attribute, 8> newAttrs(numTotalIndices, 54 DictionaryAttr::get(op->getContext())); 55 newAttrs[index] = attrs; 56 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); 57 return; 58 } 59 // Check to see if the attribute is different from what we already have. 60 if (allAttrs[index] == attrs) 61 return; 62 63 // If it is, check to see if the attribute array would now contain only empty 64 // dictionaries. 65 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); 66 if (attrs.empty() && 67 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && 68 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { 69 op->removeAttr(attrName); 70 return; 71 } 72 73 // Otherwise, create a new attribute array with the updated dictionary. 74 SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end()); 75 newAttrs[index] = attrs; 76 op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); 77 } 78 79 /// Set all of the argument or result attribute dictionaries for a function. 80 static void setAllArgResAttrDicts(Operation *op, StringRef attrName, 81 ArrayRef<Attribute> attrs) { 82 if (llvm::all_of(attrs, isEmptyAttrDict)) 83 op->removeAttr(attrName); 84 else 85 op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); 86 } 87 88 void mlir::function_interface_impl::setAllArgAttrDicts( 89 Operation *op, ArrayRef<DictionaryAttr> attrs) { 90 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 91 } 92 void mlir::function_interface_impl::setAllArgAttrDicts( 93 Operation *op, ArrayRef<Attribute> attrs) { 94 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 95 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 96 }); 97 setAllArgResAttrDicts(op, getArgDictAttrName(), 98 llvm::to_vector<8>(wrappedAttrs)); 99 } 100 101 void mlir::function_interface_impl::setAllResultAttrDicts( 102 Operation *op, ArrayRef<DictionaryAttr> attrs) { 103 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); 104 } 105 void mlir::function_interface_impl::setAllResultAttrDicts( 106 Operation *op, ArrayRef<Attribute> attrs) { 107 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { 108 return !attr ? DictionaryAttr::get(op->getContext()) : attr; 109 }); 110 setAllArgResAttrDicts(op, getResultDictAttrName(), 111 llvm::to_vector<8>(wrappedAttrs)); 112 } 113 114 void mlir::function_interface_impl::insertFunctionArguments( 115 Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes, 116 ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs, 117 unsigned originalNumArgs, Type newType) { 118 assert(argIndices.size() == argTypes.size()); 119 assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); 120 assert(argIndices.size() == argLocs.size()); 121 if (argIndices.empty()) 122 return; 123 124 // There are 3 things that need to be updated: 125 // - Function type. 126 // - Arg attrs. 127 // - Block arguments of entry block. 128 Block &entry = op->getRegion(0).front(); 129 130 // Update the argument attributes of the function. 131 auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 132 if (oldArgAttrs || !argAttrs.empty()) { 133 SmallVector<DictionaryAttr, 4> newArgAttrs; 134 newArgAttrs.reserve(originalNumArgs + argIndices.size()); 135 unsigned oldIdx = 0; 136 auto migrate = [&](unsigned untilIdx) { 137 if (!oldArgAttrs) { 138 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); 139 } else { 140 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>(); 141 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, 142 oldArgAttrRange.begin() + untilIdx); 143 } 144 oldIdx = untilIdx; 145 }; 146 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { 147 migrate(argIndices[i]); 148 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); 149 } 150 migrate(originalNumArgs); 151 setAllArgAttrDicts(op, newArgAttrs); 152 } 153 154 // Update the function type and any entry block arguments. 155 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 156 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) 157 entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); 158 } 159 160 void mlir::function_interface_impl::insertFunctionResults( 161 Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes, 162 ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults, 163 Type newType) { 164 assert(resultIndices.size() == resultTypes.size()); 165 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); 166 if (resultIndices.empty()) 167 return; 168 169 // There are 2 things that need to be updated: 170 // - Function type. 171 // - Result attrs. 172 173 // Update the result attributes of the function. 174 auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 175 if (oldResultAttrs || !resultAttrs.empty()) { 176 SmallVector<DictionaryAttr, 4> newResultAttrs; 177 newResultAttrs.reserve(originalNumResults + resultIndices.size()); 178 unsigned oldIdx = 0; 179 auto migrate = [&](unsigned untilIdx) { 180 if (!oldResultAttrs) { 181 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); 182 } else { 183 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>(); 184 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, 185 oldResultAttrsRange.begin() + untilIdx); 186 } 187 oldIdx = untilIdx; 188 }; 189 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { 190 migrate(resultIndices[i]); 191 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} 192 : resultAttrs[i]); 193 } 194 migrate(originalNumResults); 195 setAllResultAttrDicts(op, newResultAttrs); 196 } 197 198 // Update the function type. 199 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 200 } 201 202 void mlir::function_interface_impl::eraseFunctionArguments( 203 Operation *op, const BitVector &argIndices, Type newType) { 204 // There are 3 things that need to be updated: 205 // - Function type. 206 // - Arg attrs. 207 // - Block arguments of entry block. 208 Block &entry = op->getRegion(0).front(); 209 210 // Update the argument attributes of the function. 211 if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) { 212 SmallVector<DictionaryAttr, 4> newArgAttrs; 213 newArgAttrs.reserve(argAttrs.size()); 214 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) 215 if (!argIndices[i]) 216 newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>()); 217 setAllArgAttrDicts(op, newArgAttrs); 218 } 219 220 // Update the function type and any entry block arguments. 221 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 222 entry.eraseArguments(argIndices); 223 } 224 225 void mlir::function_interface_impl::eraseFunctionResults( 226 Operation *op, const BitVector &resultIndices, Type newType) { 227 // There are 2 things that need to be updated: 228 // - Function type. 229 // - Result attrs. 230 231 // Update the result attributes of the function. 232 if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) { 233 SmallVector<DictionaryAttr, 4> newResultAttrs; 234 newResultAttrs.reserve(resAttrs.size()); 235 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) 236 if (!resultIndices[i]) 237 newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>()); 238 setAllResultAttrDicts(op, newResultAttrs); 239 } 240 241 // Update the function type. 242 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 243 } 244 245 TypeRange mlir::function_interface_impl::insertTypesInto( 246 TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes, 247 SmallVectorImpl<Type> &storage) { 248 assert(indices.size() == newTypes.size() && 249 "mismatch between indice and type count"); 250 if (indices.empty()) 251 return oldTypes; 252 253 auto fromIt = oldTypes.begin(); 254 for (auto it : llvm::zip(indices, newTypes)) { 255 const auto toIt = oldTypes.begin() + std::get<0>(it); 256 storage.append(fromIt, toIt); 257 storage.push_back(std::get<1>(it)); 258 fromIt = toIt; 259 } 260 storage.append(fromIt, oldTypes.end()); 261 return storage; 262 } 263 264 TypeRange 265 mlir::function_interface_impl::filterTypesOut(TypeRange types, 266 const BitVector &indices, 267 SmallVectorImpl<Type> &storage) { 268 if (indices.none()) 269 return types; 270 271 for (unsigned i = 0, e = types.size(); i < e; ++i) 272 if (!indices[i]) 273 storage.emplace_back(types[i]); 274 return storage; 275 } 276 277 //===----------------------------------------------------------------------===// 278 // Function type signature. 279 //===----------------------------------------------------------------------===// 280 281 void mlir::function_interface_impl::setFunctionType(Operation *op, 282 Type newType) { 283 FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 284 unsigned oldNumArgs = funcOp.getNumArguments(); 285 unsigned oldNumResults = funcOp.getNumResults(); 286 op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); 287 unsigned newNumArgs = funcOp.getNumArguments(); 288 unsigned newNumResults = funcOp.getNumResults(); 289 290 // Functor used to update the argument and result attributes of the function. 291 auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, 292 unsigned newCount, auto setAttrFn) { 293 if (oldCount == newCount) 294 return; 295 // The new type has no arguments/results, just drop the attribute. 296 if (newCount == 0) { 297 op->removeAttr(attrName); 298 return; 299 } 300 ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName); 301 if (!attrs) 302 return; 303 304 // The new type has less arguments/results, take the first N attributes. 305 if (newCount < oldCount) 306 return setAttrFn(op, attrs.getValue().take_front(newCount)); 307 308 // Otherwise, the new type has more arguments/results. Initialize the new 309 // arguments/results with empty attributes. 310 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end()); 311 newAttrs.resize(newCount); 312 setAttrFn(op, newAttrs); 313 }; 314 315 // Update the argument and result attributes. 316 updateAttrFn( 317 getArgDictAttrName(), oldNumArgs, newNumArgs, 318 [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); 319 updateAttrFn( 320 getResultDictAttrName(), oldNumResults, newNumResults, 321 [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); 322 } 323