1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/FunctionImplementation.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionInterfaces.h" 12 #include "mlir/IR/SymbolTable.h" 13 14 using namespace mlir; 15 16 ParseResult mlir::function_interface_impl::parseFunctionArgumentList( 17 OpAsmParser &parser, bool allowAttributes, bool allowVariadic, 18 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames, 19 SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, 20 bool &isVariadic) { 21 if (parser.parseLParen()) 22 return failure(); 23 24 // The argument list either has to consistently have ssa-id's followed by 25 // types, or just be a type list. It isn't ok to sometimes have SSA ID's and 26 // sometimes not. 27 auto parseArgument = [&]() -> ParseResult { 28 SMLoc loc = parser.getCurrentLocation(); 29 30 // Parse argument name if present. 31 OpAsmParser::UnresolvedOperand argument; 32 Type argumentType; 33 auto hadSSAValue = parser.parseOptionalOperand(argument, 34 /*allowResultNumber=*/false); 35 if (hadSSAValue.hasValue()) { 36 if (failed(hadSSAValue.getValue())) 37 return failure(); // Argument was present but malformed. 38 39 // Reject this if the preceding argument was missing a name. 40 if (argNames.empty() && !argTypes.empty()) 41 return parser.emitError(loc, "expected type instead of SSA identifier"); 42 43 // Parse required type. 44 if (parser.parseColonType(argumentType)) 45 return failure(); 46 } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { 47 isVariadic = true; 48 return success(); 49 } else if (!argNames.empty()) { 50 // Reject this if the preceding argument had a name. 51 return parser.emitError(loc, "expected SSA identifier"); 52 } else if (parser.parseType(argumentType)) { 53 return failure(); 54 } 55 56 // Add the argument type. 57 argTypes.push_back(argumentType); 58 59 // Parse any argument attributes and source location information. 60 NamedAttrList attrs; 61 if (parser.parseOptionalAttrDict(attrs) || 62 parser.parseOptionalLocationSpecifier(argument.sourceLoc)) 63 return failure(); 64 65 if (!allowAttributes && !attrs.empty()) 66 return parser.emitError(loc, "expected arguments without attributes"); 67 argAttrs.push_back(attrs); 68 69 // If we had an argument name, then remember the parsed argument. 70 if (!argument.name.empty()) 71 argNames.push_back(argument); 72 return success(); 73 }; 74 75 // Parse the function arguments. 76 isVariadic = false; 77 if (failed(parser.parseOptionalRParen())) { 78 do { 79 unsigned numTypedArguments = argTypes.size(); 80 if (parseArgument()) 81 return failure(); 82 83 SMLoc loc = parser.getCurrentLocation(); 84 if (argTypes.size() == numTypedArguments && 85 succeeded(parser.parseOptionalComma())) 86 return parser.emitError( 87 loc, "variadic arguments must be in the end of the argument list"); 88 } while (succeeded(parser.parseOptionalComma())); 89 parser.parseRParen(); 90 } 91 92 return success(); 93 } 94 95 /// Parse a function result list. 96 /// 97 /// function-result-list ::= function-result-list-parens 98 /// | non-function-type 99 /// function-result-list-parens ::= `(` `)` 100 /// | `(` function-result-list-no-parens `)` 101 /// function-result-list-no-parens ::= function-result (`,` function-result)* 102 /// function-result ::= type attribute-dict? 103 /// 104 static ParseResult 105 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, 106 SmallVectorImpl<NamedAttrList> &resultAttrs) { 107 if (failed(parser.parseOptionalLParen())) { 108 // We already know that there is no `(`, so parse a type. 109 // Because there is no `(`, it cannot be a function type. 110 Type ty; 111 if (parser.parseType(ty)) 112 return failure(); 113 resultTypes.push_back(ty); 114 resultAttrs.emplace_back(); 115 return success(); 116 } 117 118 // Special case for an empty set of parens. 119 if (succeeded(parser.parseOptionalRParen())) 120 return success(); 121 122 // Parse individual function results. 123 do { 124 resultTypes.emplace_back(); 125 resultAttrs.emplace_back(); 126 if (parser.parseType(resultTypes.back()) || 127 parser.parseOptionalAttrDict(resultAttrs.back())) { 128 return failure(); 129 } 130 } while (succeeded(parser.parseOptionalComma())); 131 return parser.parseRParen(); 132 } 133 134 ParseResult mlir::function_interface_impl::parseFunctionSignature( 135 OpAsmParser &parser, bool allowVariadic, 136 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames, 137 SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs, 138 bool &isVariadic, SmallVectorImpl<Type> &resultTypes, 139 SmallVectorImpl<NamedAttrList> &resultAttrs) { 140 bool allowArgAttrs = true; 141 if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames, 142 argTypes, argAttrs, isVariadic)) 143 return failure(); 144 if (succeeded(parser.parseOptionalArrow())) 145 return parseFunctionResultList(parser, resultTypes, resultAttrs); 146 return success(); 147 } 148 149 /// Implementation of `addArgAndResultAttrs` that is attribute list type 150 /// agnostic. 151 template <typename AttrListT, typename AttrArrayBuildFnT> 152 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, 153 ArrayRef<AttrListT> argAttrs, 154 ArrayRef<AttrListT> resultAttrs, 155 AttrArrayBuildFnT &&buildAttrArrayFn) { 156 auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); }; 157 158 // Add the attributes to the function arguments. 159 if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { 160 ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); 161 result.addAttribute(function_interface_impl::getArgDictAttrName(), 162 attrDicts); 163 } 164 // Add the attributes to the function results. 165 if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { 166 ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); 167 result.addAttribute(function_interface_impl::getResultDictAttrName(), 168 attrDicts); 169 } 170 } 171 172 void mlir::function_interface_impl::addArgAndResultAttrs( 173 Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs, 174 ArrayRef<DictionaryAttr> resultAttrs) { 175 auto buildFn = [](ArrayRef<DictionaryAttr> attrs) { 176 return ArrayRef<Attribute>(attrs.data(), attrs.size()); 177 }; 178 addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); 179 } 180 void mlir::function_interface_impl::addArgAndResultAttrs( 181 Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs, 182 ArrayRef<NamedAttrList> resultAttrs) { 183 MLIRContext *context = builder.getContext(); 184 auto buildFn = [=](ArrayRef<NamedAttrList> attrs) { 185 return llvm::to_vector<8>( 186 llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute { 187 return attrList.getDictionary(context); 188 })); 189 }; 190 addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); 191 } 192 193 ParseResult mlir::function_interface_impl::parseFunctionOp( 194 OpAsmParser &parser, OperationState &result, bool allowVariadic, 195 FuncTypeBuilder funcTypeBuilder) { 196 SmallVector<OpAsmParser::UnresolvedOperand> entryArgs; 197 SmallVector<NamedAttrList> argAttrs; 198 SmallVector<NamedAttrList> resultAttrs; 199 SmallVector<Type> argTypes; 200 SmallVector<Type> resultTypes; 201 auto &builder = parser.getBuilder(); 202 203 // Parse visibility. 204 impl::parseOptionalVisibilityKeyword(parser, result.attributes); 205 206 // Parse the name as a symbol. 207 StringAttr nameAttr; 208 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 209 result.attributes)) 210 return failure(); 211 212 // Parse the function signature. 213 SMLoc signatureLocation = parser.getCurrentLocation(); 214 bool isVariadic = false; 215 if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, 216 argAttrs, isVariadic, resultTypes, resultAttrs)) 217 return failure(); 218 219 std::string errorMessage; 220 Type type = funcTypeBuilder(builder, argTypes, resultTypes, 221 VariadicFlag(isVariadic), errorMessage); 222 if (!type) { 223 return parser.emitError(signatureLocation) 224 << "failed to construct function type" 225 << (errorMessage.empty() ? "" : ": ") << errorMessage; 226 } 227 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 228 229 // If function attributes are present, parse them. 230 NamedAttrList parsedAttributes; 231 SMLoc attributeDictLocation = parser.getCurrentLocation(); 232 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) 233 return failure(); 234 235 // Disallow attributes that are inferred from elsewhere in the attribute 236 // dictionary. 237 for (StringRef disallowed : 238 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), 239 getTypeAttrName()}) { 240 if (parsedAttributes.get(disallowed)) 241 return parser.emitError(attributeDictLocation, "'") 242 << disallowed 243 << "' is an inferred attribute and should not be specified in the " 244 "explicit attribute dictionary"; 245 } 246 result.attributes.append(parsedAttributes); 247 248 // Add the attributes to the function arguments. 249 assert(argAttrs.size() == argTypes.size()); 250 assert(resultAttrs.size() == resultTypes.size()); 251 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 252 253 // Parse the optional function body. The printer will not print the body if 254 // its empty, so disallow parsing of empty body in the parser. 255 auto *body = result.addRegion(); 256 SMLoc loc = parser.getCurrentLocation(); 257 OptionalParseResult parseResult = parser.parseOptionalRegion( 258 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes, 259 /*enableNameShadowing=*/false); 260 if (parseResult.hasValue()) { 261 if (failed(*parseResult)) 262 return failure(); 263 // Function body was parsed, make sure its not empty. 264 if (body->empty()) 265 return parser.emitError(loc, "expected non-empty function body"); 266 } 267 return success(); 268 } 269 270 /// Print a function result list. The provided `attrs` must either be null, or 271 /// contain a set of DictionaryAttrs of the same arity as `types`. 272 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, 273 ArrayAttr attrs) { 274 assert(!types.empty() && "Should not be called for empty result list."); 275 assert((!attrs || attrs.size() == types.size()) && 276 "Invalid number of attributes."); 277 278 auto &os = p.getStream(); 279 bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() || 280 (attrs && !attrs[0].cast<DictionaryAttr>().empty()); 281 if (needsParens) 282 os << '('; 283 llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) { 284 p.printType(types[i]); 285 if (attrs) 286 p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue()); 287 }); 288 if (needsParens) 289 os << ')'; 290 } 291 292 void mlir::function_interface_impl::printFunctionSignature( 293 OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic, 294 ArrayRef<Type> resultTypes) { 295 Region &body = op->getRegion(0); 296 bool isExternal = body.empty(); 297 298 p << '('; 299 ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 300 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 301 if (i > 0) 302 p << ", "; 303 304 if (!isExternal) { 305 ArrayRef<NamedAttribute> attrs; 306 if (argAttrs) 307 attrs = argAttrs[i].cast<DictionaryAttr>().getValue(); 308 p.printRegionArgument(body.getArgument(i), attrs); 309 } else { 310 p.printType(argTypes[i]); 311 if (argAttrs) 312 p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue()); 313 } 314 } 315 316 if (isVariadic) { 317 if (!argTypes.empty()) 318 p << ", "; 319 p << "..."; 320 } 321 322 p << ')'; 323 324 if (!resultTypes.empty()) { 325 p.getStream() << " -> "; 326 auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 327 printFunctionResultList(p, resultTypes, resultAttrs); 328 } 329 } 330 331 void mlir::function_interface_impl::printFunctionAttributes( 332 OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, 333 ArrayRef<StringRef> elided) { 334 // Print out function attributes, if present. 335 SmallVector<StringRef, 2> ignoredAttrs = { 336 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), 337 getArgDictAttrName(), getResultDictAttrName()}; 338 ignoredAttrs.append(elided.begin(), elided.end()); 339 340 p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); 341 } 342 343 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, 344 FunctionOpInterface op, 345 bool isVariadic) { 346 // Print the operation and the function name. 347 auto funcName = 348 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) 349 .getValue(); 350 p << ' '; 351 352 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); 353 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) 354 p << visibility.getValue() << ' '; 355 p.printSymbolName(funcName); 356 357 ArrayRef<Type> argTypes = op.getArgumentTypes(); 358 ArrayRef<Type> resultTypes = op.getResultTypes(); 359 printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); 360 printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), 361 {visibilityAttrName}); 362 // Print the body if this is not an external function. 363 Region &body = op->getRegion(0); 364 if (!body.empty()) { 365 p << ' '; 366 p.printRegion(body, /*printEntryBlockArgs=*/false, 367 /*printBlockTerminators=*/true); 368 } 369 } 370