1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/FunctionImplementation.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/FunctionSupport.h" 21 #include "mlir/IR/SymbolTable.h" 22 23 using namespace mlir; 24 25 static ParseResult 26 parseArgumentList(OpAsmParser &parser, bool allowVariadic, 27 SmallVectorImpl<Type> &argTypes, 28 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 29 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, 30 bool &isVariadic) { 31 if (parser.parseLParen()) 32 return failure(); 33 34 // The argument list either has to consistently have ssa-id's followed by 35 // types, or just be a type list. It isn't ok to sometimes have SSA ID's and 36 // sometimes not. 37 auto parseArgument = [&]() -> ParseResult { 38 llvm::SMLoc loc = parser.getCurrentLocation(); 39 40 // Parse argument name if present. 41 OpAsmParser::OperandType argument; 42 Type argumentType; 43 if (succeeded(parser.parseOptionalRegionArgument(argument)) && 44 !argument.name.empty()) { 45 // Reject this if the preceding argument was missing a name. 46 if (argNames.empty() && !argTypes.empty()) 47 return parser.emitError(loc, "expected type instead of SSA identifier"); 48 argNames.push_back(argument); 49 50 if (parser.parseColonType(argumentType)) 51 return failure(); 52 } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { 53 isVariadic = true; 54 return success(); 55 } else if (!argNames.empty()) { 56 // Reject this if the preceding argument had a name. 57 return parser.emitError(loc, "expected SSA identifier"); 58 } else if (parser.parseType(argumentType)) { 59 return failure(); 60 } 61 62 // Add the argument type. 63 argTypes.push_back(argumentType); 64 65 // Parse any argument attributes. 66 SmallVector<NamedAttribute, 2> attrs; 67 if (parser.parseOptionalAttrDict(attrs)) 68 return failure(); 69 argAttrs.push_back(attrs); 70 return success(); 71 }; 72 73 // Parse the function arguments. 74 isVariadic = false; 75 if (failed(parser.parseOptionalRParen())) { 76 do { 77 unsigned numTypedArguments = argTypes.size(); 78 if (parseArgument()) 79 return failure(); 80 81 llvm::SMLoc loc = parser.getCurrentLocation(); 82 if (argTypes.size() == numTypedArguments && 83 succeeded(parser.parseOptionalComma())) 84 return parser.emitError( 85 loc, "variadic arguments must be in the end of the argument list"); 86 } while (succeeded(parser.parseOptionalComma())); 87 parser.parseRParen(); 88 } 89 90 return success(); 91 } 92 93 /// Parse a function result list. 94 /// 95 /// function-result-list ::= function-result-list-parens 96 /// | non-function-type 97 /// function-result-list-parens ::= `(` `)` 98 /// | `(` function-result-list-no-parens `)` 99 /// function-result-list-no-parens ::= function-result (`,` function-result)* 100 /// function-result ::= type attribute-dict? 101 /// 102 static ParseResult parseFunctionResultList( 103 OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, 104 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) { 105 if (failed(parser.parseOptionalLParen())) { 106 // We already know that there is no `(`, so parse a type. 107 // Because there is no `(`, it cannot be a function type. 108 Type ty; 109 if (parser.parseType(ty)) 110 return failure(); 111 resultTypes.push_back(ty); 112 resultAttrs.emplace_back(); 113 return success(); 114 } 115 116 // Special case for an empty set of parens. 117 if (succeeded(parser.parseOptionalRParen())) 118 return success(); 119 120 // Parse individual function results. 121 do { 122 resultTypes.emplace_back(); 123 resultAttrs.emplace_back(); 124 if (parser.parseType(resultTypes.back()) || 125 parser.parseOptionalAttrDict(resultAttrs.back())) { 126 return failure(); 127 } 128 } while (succeeded(parser.parseOptionalComma())); 129 return parser.parseRParen(); 130 } 131 132 /// Parses a function signature using `parser`. The `allowVariadic` argument 133 /// indicates whether functions with variadic arguments are supported. The 134 /// trailing arguments are populated by this function with names, types and 135 /// attributes of the arguments and those of the results. 136 ParseResult mlir::impl::parseFunctionSignature( 137 OpAsmParser &parser, bool allowVariadic, 138 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 139 SmallVectorImpl<Type> &argTypes, 140 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic, 141 SmallVectorImpl<Type> &resultTypes, 142 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) { 143 if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, 144 isVariadic)) 145 return failure(); 146 if (succeeded(parser.parseOptionalArrow())) 147 return parseFunctionResultList(parser, resultTypes, resultAttrs); 148 return success(); 149 } 150 151 void mlir::impl::addArgAndResultAttrs( 152 Builder &builder, OperationState &result, 153 ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs, 154 ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) { 155 // Add the attributes to the function arguments. 156 SmallString<8> attrNameBuf; 157 for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) 158 if (!argAttrs[i].empty()) 159 result.addAttribute(getArgAttrName(i, attrNameBuf), 160 builder.getDictionaryAttr(argAttrs[i])); 161 162 // Add the attributes to the function results. 163 for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) 164 if (!resultAttrs[i].empty()) 165 result.addAttribute(getResultAttrName(i, attrNameBuf), 166 builder.getDictionaryAttr(resultAttrs[i])); 167 } 168 169 /// Parser implementation for function-like operations. Uses `funcTypeBuilder` 170 /// to construct the custom function type given lists of input and output types. 171 ParseResult 172 mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, 173 bool allowVariadic, 174 mlir::impl::FuncTypeBuilder funcTypeBuilder) { 175 SmallVector<OpAsmParser::OperandType, 4> entryArgs; 176 SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs; 177 SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs; 178 SmallVector<Type, 4> argTypes; 179 SmallVector<Type, 4> resultTypes; 180 auto &builder = parser.getBuilder(); 181 182 // Parse the name as a symbol. 183 StringAttr nameAttr; 184 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 185 result.attributes)) 186 return failure(); 187 188 // Parse the function signature. 189 auto signatureLocation = parser.getCurrentLocation(); 190 bool isVariadic = false; 191 if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, 192 argAttrs, isVariadic, resultTypes, resultAttrs)) 193 return failure(); 194 195 std::string errorMessage; 196 if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, 197 impl::VariadicFlag(isVariadic), errorMessage)) 198 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 199 else 200 return parser.emitError(signatureLocation) 201 << "failed to construct function type" 202 << (errorMessage.empty() ? "" : ": ") << errorMessage; 203 204 // If function attributes are present, parse them. 205 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 206 return failure(); 207 208 // Add the attributes to the function arguments. 209 assert(argAttrs.size() == argTypes.size()); 210 assert(resultAttrs.size() == resultTypes.size()); 211 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 212 213 // Parse the optional function body. 214 auto *body = result.addRegion(); 215 return parser.parseOptionalRegion( 216 *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes); 217 } 218 219 // Print a function result list. 220 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, 221 ArrayRef<ArrayRef<NamedAttribute>> attrs) { 222 assert(!types.empty() && "Should not be called for empty result list."); 223 auto &os = p.getStream(); 224 bool needsParens = 225 types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty(); 226 if (needsParens) 227 os << '('; 228 interleaveComma(llvm::zip(types, attrs), os, 229 [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) { 230 p.printType(std::get<0>(t)); 231 p.printOptionalAttrDict(std::get<1>(t)); 232 }); 233 if (needsParens) 234 os << ')'; 235 } 236 237 /// Print the signature of the function-like operation `op`. Assumes `op` has 238 /// the FunctionLike trait and passed the verification. 239 void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, 240 ArrayRef<Type> argTypes, 241 bool isVariadic, 242 ArrayRef<Type> resultTypes) { 243 Region &body = op->getRegion(0); 244 bool isExternal = body.empty(); 245 246 p << '('; 247 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 248 if (i > 0) 249 p << ", "; 250 251 if (!isExternal) { 252 p.printOperand(body.front().getArgument(i)); 253 p << ": "; 254 } 255 256 p.printType(argTypes[i]); 257 p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); 258 } 259 260 if (isVariadic) { 261 if (!argTypes.empty()) 262 p << ", "; 263 p << "..."; 264 } 265 266 p << ')'; 267 268 if (!resultTypes.empty()) { 269 p.getStream() << " -> "; 270 SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs; 271 for (int i = 0, e = resultTypes.size(); i < e; ++i) 272 resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); 273 printFunctionResultList(p, resultTypes, resultAttrs); 274 } 275 } 276 277 /// Prints the list of function prefixed with the "attributes" keyword. The 278 /// attributes with names listed in "elided" as well as those used by the 279 /// function-like operation internally are not printed. Nothing is printed 280 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and 281 /// passed the verification. 282 void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, 283 unsigned numInputs, 284 unsigned numResults, 285 ArrayRef<StringRef> elided) { 286 // Print out function attributes, if present. 287 SmallVector<StringRef, 2> ignoredAttrs = { 288 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; 289 ignoredAttrs.append(elided.begin(), elided.end()); 290 291 SmallString<8> attrNameBuf; 292 293 // Ignore any argument attributes. 294 std::vector<SmallString<8>> argAttrStorage; 295 for (unsigned i = 0; i != numInputs; ++i) 296 if (op->getAttr(getArgAttrName(i, attrNameBuf))) 297 argAttrStorage.emplace_back(attrNameBuf); 298 ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); 299 300 // Ignore any result attributes. 301 std::vector<SmallString<8>> resultAttrStorage; 302 for (unsigned i = 0; i != numResults; ++i) 303 if (op->getAttr(getResultAttrName(i, attrNameBuf))) 304 resultAttrStorage.emplace_back(attrNameBuf); 305 ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); 306 307 p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); 308 } 309 310 /// Printer implementation for function-like operations. Accepts lists of 311 /// argument and result types to use while printing. 312 void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, 313 ArrayRef<Type> argTypes, bool isVariadic, 314 ArrayRef<Type> resultTypes) { 315 // Print the operation and the function name. 316 auto funcName = 317 op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) 318 .getValue(); 319 p << op->getName() << ' '; 320 p.printSymbolName(funcName); 321 322 printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); 323 printFunctionAttributes(p, op, argTypes.size(), resultTypes.size()); 324 325 // Print the body if this is not an external function. 326 Region &body = op->getRegion(0); 327 if (!body.empty()) 328 p.printRegion(body, /*printEntryBlockArgs=*/false, 329 /*printBlockTerminators=*/true); 330 } 331