1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/FunctionImplementation.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionInterfaces.h" 12 #include "mlir/IR/SymbolTable.h" 13 14 using namespace mlir; 15 16 ParseResult mlir::function_interface_impl::parseFunctionArgumentList( 17 OpAsmParser &parser, bool allowAttributes, bool allowVariadic, 18 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames, 19 SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, 20 bool &isVariadic) { 21 if (parser.parseLParen()) 22 return failure(); 23 24 // The argument list either has to consistently have ssa-id's followed by 25 // types, or just be a type list. It isn't ok to sometimes have SSA ID's and 26 // sometimes not. 27 auto parseArgument = [&]() -> ParseResult { 28 SMLoc loc = parser.getCurrentLocation(); 29 30 // Parse argument name if present. 31 OpAsmParser::UnresolvedOperand argument; 32 Type argumentType; 33 if (succeeded(parser.parseOptionalRegionArgument(argument)) && 34 !argument.name.empty()) { 35 // Reject this if the preceding argument was missing a name. 36 if (argNames.empty() && !argTypes.empty()) 37 return parser.emitError(loc, "expected type instead of SSA identifier"); 38 39 // Parse required type. 40 if (parser.parseColonType(argumentType)) 41 return failure(); 42 } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { 43 isVariadic = true; 44 return success(); 45 } else if (!argNames.empty()) { 46 // Reject this if the preceding argument had a name. 47 return parser.emitError(loc, "expected SSA identifier"); 48 } else if (parser.parseType(argumentType)) { 49 return failure(); 50 } 51 52 // Add the argument type. 53 argTypes.push_back(argumentType); 54 55 // Parse any argument attributes and source location information. 56 NamedAttrList attrs; 57 if (parser.parseOptionalAttrDict(attrs) || 58 parser.parseOptionalLocationSpecifier(argument.sourceLoc)) 59 return failure(); 60 61 if (!allowAttributes && !attrs.empty()) 62 return parser.emitError(loc, "expected arguments without attributes"); 63 argAttrs.push_back(attrs); 64 65 // If we had an argument name, then remember the parsed argument. 66 if (!argument.name.empty()) 67 argNames.push_back(argument); 68 return success(); 69 }; 70 71 // Parse the function arguments. 72 isVariadic = false; 73 if (failed(parser.parseOptionalRParen())) { 74 do { 75 unsigned numTypedArguments = argTypes.size(); 76 if (parseArgument()) 77 return failure(); 78 79 SMLoc loc = parser.getCurrentLocation(); 80 if (argTypes.size() == numTypedArguments && 81 succeeded(parser.parseOptionalComma())) 82 return parser.emitError( 83 loc, "variadic arguments must be in the end of the argument list"); 84 } while (succeeded(parser.parseOptionalComma())); 85 parser.parseRParen(); 86 } 87 88 return success(); 89 } 90 91 /// Parse a function result list. 92 /// 93 /// function-result-list ::= function-result-list-parens 94 /// | non-function-type 95 /// function-result-list-parens ::= `(` `)` 96 /// | `(` function-result-list-no-parens `)` 97 /// function-result-list-no-parens ::= function-result (`,` function-result)* 98 /// function-result ::= type attribute-dict? 99 /// 100 static ParseResult 101 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, 102 SmallVectorImpl<NamedAttrList> &resultAttrs) { 103 if (failed(parser.parseOptionalLParen())) { 104 // We already know that there is no `(`, so parse a type. 105 // Because there is no `(`, it cannot be a function type. 106 Type ty; 107 if (parser.parseType(ty)) 108 return failure(); 109 resultTypes.push_back(ty); 110 resultAttrs.emplace_back(); 111 return success(); 112 } 113 114 // Special case for an empty set of parens. 115 if (succeeded(parser.parseOptionalRParen())) 116 return success(); 117 118 // Parse individual function results. 119 do { 120 resultTypes.emplace_back(); 121 resultAttrs.emplace_back(); 122 if (parser.parseType(resultTypes.back()) || 123 parser.parseOptionalAttrDict(resultAttrs.back())) { 124 return failure(); 125 } 126 } while (succeeded(parser.parseOptionalComma())); 127 return parser.parseRParen(); 128 } 129 130 ParseResult mlir::function_interface_impl::parseFunctionSignature( 131 OpAsmParser &parser, bool allowVariadic, 132 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames, 133 SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, 134 bool &isVariadic, SmallVectorImpl<Type> &resultTypes, 135 SmallVectorImpl<NamedAttrList> &resultAttrs) { 136 bool allowArgAttrs = true; 137 if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames, 138 argTypes, argAttrs, isVariadic)) 139 return failure(); 140 if (succeeded(parser.parseOptionalArrow())) 141 return parseFunctionResultList(parser, resultTypes, resultAttrs); 142 return success(); 143 } 144 145 /// Implementation of `addArgAndResultAttrs` that is attribute list type 146 /// agnostic. 147 template <typename AttrListT, typename AttrArrayBuildFnT> 148 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, 149 ArrayRef<AttrListT> argAttrs, 150 ArrayRef<AttrListT> resultAttrs, 151 AttrArrayBuildFnT &&buildAttrArrayFn) { 152 auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); }; 153 154 // Add the attributes to the function arguments. 155 if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { 156 ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); 157 result.addAttribute(function_interface_impl::getArgDictAttrName(), 158 attrDicts); 159 } 160 // Add the attributes to the function results. 161 if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { 162 ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); 163 result.addAttribute(function_interface_impl::getResultDictAttrName(), 164 attrDicts); 165 } 166 } 167 168 void mlir::function_interface_impl::addArgAndResultAttrs( 169 Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs, 170 ArrayRef<DictionaryAttr> resultAttrs) { 171 auto buildFn = [](ArrayRef<DictionaryAttr> attrs) { 172 return ArrayRef<Attribute>(attrs.data(), attrs.size()); 173 }; 174 addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); 175 } 176 void mlir::function_interface_impl::addArgAndResultAttrs( 177 Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs, 178 ArrayRef<NamedAttrList> resultAttrs) { 179 MLIRContext *context = builder.getContext(); 180 auto buildFn = [=](ArrayRef<NamedAttrList> attrs) { 181 return llvm::to_vector<8>( 182 llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute { 183 return attrList.getDictionary(context); 184 })); 185 }; 186 addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); 187 } 188 189 ParseResult mlir::function_interface_impl::parseFunctionOp( 190 OpAsmParser &parser, OperationState &result, bool allowVariadic, 191 FuncTypeBuilder funcTypeBuilder) { 192 SmallVector<OpAsmParser::UnresolvedOperand> entryArgs; 193 SmallVector<NamedAttrList> argAttrs; 194 SmallVector<NamedAttrList> resultAttrs; 195 SmallVector<Type> argTypes; 196 SmallVector<Type> resultTypes; 197 auto &builder = parser.getBuilder(); 198 199 // Parse visibility. 200 impl::parseOptionalVisibilityKeyword(parser, result.attributes); 201 202 // Parse the name as a symbol. 203 StringAttr nameAttr; 204 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 205 result.attributes)) 206 return failure(); 207 208 // Parse the function signature. 209 SMLoc signatureLocation = parser.getCurrentLocation(); 210 bool isVariadic = false; 211 if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, 212 argAttrs, isVariadic, resultTypes, resultAttrs)) 213 return failure(); 214 215 std::string errorMessage; 216 Type type = funcTypeBuilder(builder, argTypes, resultTypes, 217 VariadicFlag(isVariadic), errorMessage); 218 if (!type) { 219 return parser.emitError(signatureLocation) 220 << "failed to construct function type" 221 << (errorMessage.empty() ? "" : ": ") << errorMessage; 222 } 223 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 224 225 // If function attributes are present, parse them. 226 NamedAttrList parsedAttributes; 227 SMLoc attributeDictLocation = parser.getCurrentLocation(); 228 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) 229 return failure(); 230 231 // Disallow attributes that are inferred from elsewhere in the attribute 232 // dictionary. 233 for (StringRef disallowed : 234 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), 235 getTypeAttrName()}) { 236 if (parsedAttributes.get(disallowed)) 237 return parser.emitError(attributeDictLocation, "'") 238 << disallowed 239 << "' is an inferred attribute and should not be specified in the " 240 "explicit attribute dictionary"; 241 } 242 result.attributes.append(parsedAttributes); 243 244 // Add the attributes to the function arguments. 245 assert(argAttrs.size() == argTypes.size()); 246 assert(resultAttrs.size() == resultTypes.size()); 247 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 248 249 // Parse the optional function body. The printer will not print the body if 250 // its empty, so disallow parsing of empty body in the parser. 251 auto *body = result.addRegion(); 252 SMLoc loc = parser.getCurrentLocation(); 253 OptionalParseResult parseResult = parser.parseOptionalRegion( 254 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes, 255 /*enableNameShadowing=*/false); 256 if (parseResult.hasValue()) { 257 if (failed(*parseResult)) 258 return failure(); 259 // Function body was parsed, make sure its not empty. 260 if (body->empty()) 261 return parser.emitError(loc, "expected non-empty function body"); 262 } 263 return success(); 264 } 265 266 /// Print a function result list. The provided `attrs` must either be null, or 267 /// contain a set of DictionaryAttrs of the same arity as `types`. 268 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, 269 ArrayAttr attrs) { 270 assert(!types.empty() && "Should not be called for empty result list."); 271 assert((!attrs || attrs.size() == types.size()) && 272 "Invalid number of attributes."); 273 274 auto &os = p.getStream(); 275 bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() || 276 (attrs && !attrs[0].cast<DictionaryAttr>().empty()); 277 if (needsParens) 278 os << '('; 279 llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) { 280 p.printType(types[i]); 281 if (attrs) 282 p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue()); 283 }); 284 if (needsParens) 285 os << ')'; 286 } 287 288 void mlir::function_interface_impl::printFunctionSignature( 289 OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic, 290 ArrayRef<Type> resultTypes) { 291 Region &body = op->getRegion(0); 292 bool isExternal = body.empty(); 293 294 p << '('; 295 ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 296 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 297 if (i > 0) 298 p << ", "; 299 300 if (!isExternal) { 301 ArrayRef<NamedAttribute> attrs; 302 if (argAttrs) 303 attrs = argAttrs[i].cast<DictionaryAttr>().getValue(); 304 p.printRegionArgument(body.getArgument(i), attrs); 305 } else { 306 p.printType(argTypes[i]); 307 if (argAttrs) 308 p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue()); 309 } 310 } 311 312 if (isVariadic) { 313 if (!argTypes.empty()) 314 p << ", "; 315 p << "..."; 316 } 317 318 p << ')'; 319 320 if (!resultTypes.empty()) { 321 p.getStream() << " -> "; 322 auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 323 printFunctionResultList(p, resultTypes, resultAttrs); 324 } 325 } 326 327 void mlir::function_interface_impl::printFunctionAttributes( 328 OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, 329 ArrayRef<StringRef> elided) { 330 // Print out function attributes, if present. 331 SmallVector<StringRef, 2> ignoredAttrs = { 332 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), 333 getArgDictAttrName(), getResultDictAttrName()}; 334 ignoredAttrs.append(elided.begin(), elided.end()); 335 336 p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); 337 } 338 339 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, 340 FunctionOpInterface op, 341 bool isVariadic) { 342 // Print the operation and the function name. 343 auto funcName = 344 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) 345 .getValue(); 346 p << ' '; 347 348 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); 349 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) 350 p << visibility.getValue() << ' '; 351 p.printSymbolName(funcName); 352 353 ArrayRef<Type> argTypes = op.getArgumentTypes(); 354 ArrayRef<Type> resultTypes = op.getResultTypes(); 355 printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); 356 printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), 357 {visibilityAttrName}); 358 // Print the body if this is not an external function. 359 Region &body = op->getRegion(0); 360 if (!body.empty()) { 361 p << ' '; 362 p.printRegion(body, /*printEntryBlockArgs=*/false, 363 /*printBlockTerminators=*/true); 364 } 365 } 366