1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/FunctionImplementation.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionInterfaces.h" 12 #include "mlir/IR/SymbolTable.h" 13 14 using namespace mlir; 15 16 static ParseResult 17 parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, 18 SmallVectorImpl<OpAsmParser::Argument> &arguments, 19 bool &isVariadic) { 20 21 // Parse the function arguments. The argument list either has to consistently 22 // have ssa-id's followed by types, or just be a type list. It isn't ok to 23 // sometimes have SSA ID's and sometimes not. 24 isVariadic = false; 25 26 return parser.parseCommaSeparatedList( 27 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 28 // Ellipsis must be at end of the list. 29 if (isVariadic) 30 return parser.emitError( 31 parser.getCurrentLocation(), 32 "variadic arguments must be in the end of the argument list"); 33 34 // Handle ellipsis as a special case. 35 if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { 36 // This is a variadic designator. 37 isVariadic = true; 38 return success(); // Stop parsing arguments. 39 } 40 // Parse argument name if present. 41 OpAsmParser::Argument argument; 42 auto argPresent = parser.parseOptionalArgument( 43 argument, /*allowType=*/true, /*allowAttrs=*/true); 44 if (argPresent.hasValue()) { 45 if (failed(argPresent.getValue())) 46 return failure(); // Present but malformed. 47 48 // Reject this if the preceding argument was missing a name. 49 if (!arguments.empty() && arguments.back().ssaName.name.empty()) 50 return parser.emitError(argument.ssaName.location, 51 "expected type instead of SSA identifier"); 52 53 } else { 54 argument.ssaName.location = parser.getCurrentLocation(); 55 // Otherwise we just have a type list without SSA names. Reject 56 // this if the preceding argument had a name. 57 if (!arguments.empty() && !arguments.back().ssaName.name.empty()) 58 return parser.emitError(argument.ssaName.location, 59 "expected SSA identifier"); 60 61 NamedAttrList attrs; 62 if (parser.parseType(argument.type) || 63 parser.parseOptionalAttrDict(attrs) || 64 parser.parseOptionalLocationSpecifier(argument.sourceLoc)) 65 return failure(); 66 argument.attrs = attrs.getDictionary(parser.getContext()); 67 } 68 arguments.push_back(argument); 69 return success(); 70 }); 71 } 72 73 /// Parse a function result list. 74 /// 75 /// function-result-list ::= function-result-list-parens 76 /// | non-function-type 77 /// function-result-list-parens ::= `(` `)` 78 /// | `(` function-result-list-no-parens `)` 79 /// function-result-list-no-parens ::= function-result (`,` function-result)* 80 /// function-result ::= type attribute-dict? 81 /// 82 static ParseResult 83 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, 84 SmallVectorImpl<DictionaryAttr> &resultAttrs) { 85 if (failed(parser.parseOptionalLParen())) { 86 // We already know that there is no `(`, so parse a type. 87 // Because there is no `(`, it cannot be a function type. 88 Type ty; 89 if (parser.parseType(ty)) 90 return failure(); 91 resultTypes.push_back(ty); 92 resultAttrs.emplace_back(); 93 return success(); 94 } 95 96 // Special case for an empty set of parens. 97 if (succeeded(parser.parseOptionalRParen())) 98 return success(); 99 100 // Parse individual function results. 101 if (parser.parseCommaSeparatedList([&]() -> ParseResult { 102 resultTypes.emplace_back(); 103 resultAttrs.emplace_back(); 104 NamedAttrList attrs; 105 if (parser.parseType(resultTypes.back()) || 106 parser.parseOptionalAttrDict(attrs)) 107 return failure(); 108 resultAttrs.back() = attrs.getDictionary(parser.getContext()); 109 return success(); 110 })) 111 return failure(); 112 113 return parser.parseRParen(); 114 } 115 116 ParseResult mlir::function_interface_impl::parseFunctionSignature( 117 OpAsmParser &parser, bool allowVariadic, 118 SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic, 119 SmallVectorImpl<Type> &resultTypes, 120 SmallVectorImpl<DictionaryAttr> &resultAttrs) { 121 if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) 122 return failure(); 123 if (succeeded(parser.parseOptionalArrow())) 124 return parseFunctionResultList(parser, resultTypes, resultAttrs); 125 return success(); 126 } 127 128 void mlir::function_interface_impl::addArgAndResultAttrs( 129 Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs, 130 ArrayRef<DictionaryAttr> resultAttrs) { 131 auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { 132 return attrs && !attrs.empty(); 133 }; 134 // Convert the specified array of dictionary attrs (which may have null 135 // entries) to an ArrayAttr of dictionaries. 136 auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) { 137 SmallVector<Attribute> attrs; 138 for (auto &dict : dictAttrs) 139 attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); 140 return builder.getArrayAttr(attrs); 141 }; 142 143 // Add the attributes to the function arguments. 144 if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) 145 result.addAttribute(function_interface_impl::getArgDictAttrName(), 146 getArrayAttr(argAttrs)); 147 148 // Add the attributes to the function results. 149 if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) 150 result.addAttribute(function_interface_impl::getResultDictAttrName(), 151 getArrayAttr(resultAttrs)); 152 } 153 154 void mlir::function_interface_impl::addArgAndResultAttrs( 155 Builder &builder, OperationState &result, 156 ArrayRef<OpAsmParser::Argument> args, 157 ArrayRef<DictionaryAttr> resultAttrs) { 158 SmallVector<DictionaryAttr> argAttrs; 159 for (const auto &arg : args) 160 argAttrs.push_back(arg.attrs); 161 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 162 } 163 164 ParseResult mlir::function_interface_impl::parseFunctionOp( 165 OpAsmParser &parser, OperationState &result, bool allowVariadic, 166 FuncTypeBuilder funcTypeBuilder) { 167 SmallVector<OpAsmParser::Argument> entryArgs; 168 SmallVector<DictionaryAttr> resultAttrs; 169 SmallVector<Type> resultTypes; 170 auto &builder = parser.getBuilder(); 171 172 // Parse visibility. 173 impl::parseOptionalVisibilityKeyword(parser, result.attributes); 174 175 // Parse the name as a symbol. 176 StringAttr nameAttr; 177 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), 178 result.attributes)) 179 return failure(); 180 181 // Parse the function signature. 182 SMLoc signatureLocation = parser.getCurrentLocation(); 183 bool isVariadic = false; 184 if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, 185 resultTypes, resultAttrs)) 186 return failure(); 187 188 std::string errorMessage; 189 SmallVector<Type> argTypes; 190 argTypes.reserve(entryArgs.size()); 191 for (auto &arg : entryArgs) 192 argTypes.push_back(arg.type); 193 Type type = funcTypeBuilder(builder, argTypes, resultTypes, 194 VariadicFlag(isVariadic), errorMessage); 195 if (!type) { 196 return parser.emitError(signatureLocation) 197 << "failed to construct function type" 198 << (errorMessage.empty() ? "" : ": ") << errorMessage; 199 } 200 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 201 202 // If function attributes are present, parse them. 203 NamedAttrList parsedAttributes; 204 SMLoc attributeDictLocation = parser.getCurrentLocation(); 205 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) 206 return failure(); 207 208 // Disallow attributes that are inferred from elsewhere in the attribute 209 // dictionary. 210 for (StringRef disallowed : 211 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), 212 getTypeAttrName()}) { 213 if (parsedAttributes.get(disallowed)) 214 return parser.emitError(attributeDictLocation, "'") 215 << disallowed 216 << "' is an inferred attribute and should not be specified in the " 217 "explicit attribute dictionary"; 218 } 219 result.attributes.append(parsedAttributes); 220 221 // Add the attributes to the function arguments. 222 assert(resultAttrs.size() == resultTypes.size()); 223 addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); 224 225 // Parse the optional function body. The printer will not print the body if 226 // its empty, so disallow parsing of empty body in the parser. 227 auto *body = result.addRegion(); 228 SMLoc loc = parser.getCurrentLocation(); 229 OptionalParseResult parseResult = 230 parser.parseOptionalRegion(*body, entryArgs, 231 /*enableNameShadowing=*/false); 232 if (parseResult.hasValue()) { 233 if (failed(*parseResult)) 234 return failure(); 235 // Function body was parsed, make sure its not empty. 236 if (body->empty()) 237 return parser.emitError(loc, "expected non-empty function body"); 238 } 239 return success(); 240 } 241 242 /// Print a function result list. The provided `attrs` must either be null, or 243 /// contain a set of DictionaryAttrs of the same arity as `types`. 244 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, 245 ArrayAttr attrs) { 246 assert(!types.empty() && "Should not be called for empty result list."); 247 assert((!attrs || attrs.size() == types.size()) && 248 "Invalid number of attributes."); 249 250 auto &os = p.getStream(); 251 bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() || 252 (attrs && !attrs[0].cast<DictionaryAttr>().empty()); 253 if (needsParens) 254 os << '('; 255 llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) { 256 p.printType(types[i]); 257 if (attrs) 258 p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue()); 259 }); 260 if (needsParens) 261 os << ')'; 262 } 263 264 void mlir::function_interface_impl::printFunctionSignature( 265 OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic, 266 ArrayRef<Type> resultTypes) { 267 Region &body = op->getRegion(0); 268 bool isExternal = body.empty(); 269 270 p << '('; 271 ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName()); 272 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 273 if (i > 0) 274 p << ", "; 275 276 if (!isExternal) { 277 ArrayRef<NamedAttribute> attrs; 278 if (argAttrs) 279 attrs = argAttrs[i].cast<DictionaryAttr>().getValue(); 280 p.printRegionArgument(body.getArgument(i), attrs); 281 } else { 282 p.printType(argTypes[i]); 283 if (argAttrs) 284 p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue()); 285 } 286 } 287 288 if (isVariadic) { 289 if (!argTypes.empty()) 290 p << ", "; 291 p << "..."; 292 } 293 294 p << ')'; 295 296 if (!resultTypes.empty()) { 297 p.getStream() << " -> "; 298 auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName()); 299 printFunctionResultList(p, resultTypes, resultAttrs); 300 } 301 } 302 303 void mlir::function_interface_impl::printFunctionAttributes( 304 OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, 305 ArrayRef<StringRef> elided) { 306 // Print out function attributes, if present. 307 SmallVector<StringRef, 2> ignoredAttrs = { 308 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), 309 getArgDictAttrName(), getResultDictAttrName()}; 310 ignoredAttrs.append(elided.begin(), elided.end()); 311 312 p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); 313 } 314 315 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, 316 FunctionOpInterface op, 317 bool isVariadic) { 318 // Print the operation and the function name. 319 auto funcName = 320 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()) 321 .getValue(); 322 p << ' '; 323 324 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); 325 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName)) 326 p << visibility.getValue() << ' '; 327 p.printSymbolName(funcName); 328 329 ArrayRef<Type> argTypes = op.getArgumentTypes(); 330 ArrayRef<Type> resultTypes = op.getResultTypes(); 331 printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); 332 printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), 333 {visibilityAttrName}); 334 // Print the body if this is not an external function. 335 Region &body = op->getRegion(0); 336 if (!body.empty()) { 337 p << ' '; 338 p.printRegion(body, /*printEntryBlockArgs=*/false, 339 /*printBlockTerminators=*/true); 340 } 341 } 342