1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/FunctionImplementation.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/FunctionSupport.h" 21 #include "mlir/IR/SymbolTable.h" 22 23 using namespace mlir; 24 25 static ParseResult 26 parseArgumentList(OpAsmParser &parser, bool allowVariadic, 27 SmallVectorImpl<Type> &argTypes, 28 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 29 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, 30 bool &isVariadic) { 31 if (parser.parseLParen()) 32 return failure(); 33 34 // The argument list either has to consistently have ssa-id's followed by 35 // types, or just be a type list. It isn't ok to sometimes have SSA ID's and 36 // sometimes not. 37 auto parseArgument = [&]() -> ParseResult { 38 llvm::SMLoc loc = parser.getCurrentLocation(); 39 40 // Parse argument name if present. 41 OpAsmParser::OperandType argument; 42 Type argumentType; 43 if (succeeded(parser.parseOptionalRegionArgument(argument)) && 44 !argument.name.empty()) { 45 // Reject this if the preceding argument was missing a name. 46 if (argNames.empty() && !argTypes.empty()) 47 return parser.emitError(loc, "expected type instead of SSA identifier"); 48 argNames.push_back(argument); 49 50 if (parser.parseColonType(argumentType)) 51 return failure(); 52 } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { 53 isVariadic = true; 54 return success(); 55 } else if (!argNames.empty()) { 56 // Reject this if the preceding argument had a name. 57 return parser.emitError(loc, "expected SSA identifier"); 58 } else if (parser.parseType(argumentType)) { 59 return failure(); 60 } 61 62 // Add the argument type. 63 argTypes.push_back(argumentType); 64 65 // Parse any argument attributes. 66 SmallVector<NamedAttribute, 2> attrs; 67 if (parser.parseOptionalAttrDict(attrs)) 68 return failure(); 69 argAttrs.push_back(attrs); 70 return success(); 71 }; 72 73 // Parse the function arguments. 74 if (parser.parseOptionalRParen()) { 75 do { 76 unsigned numTypedArguments = argTypes.size(); 77 if (parseArgument()) 78 return failure(); 79 80 llvm::SMLoc loc = parser.getCurrentLocation(); 81 if (argTypes.size() == numTypedArguments && 82 succeeded(parser.parseOptionalComma())) 83 return parser.emitError( 84 loc, "variadic arguments must be in the end of the argument list"); 85 } while (succeeded(parser.parseOptionalComma())); 86 parser.parseRParen(); 87 } 88 89 return success(); 90 } 91 92 /// Parse a function result list. 93 /// 94 /// function-result-list ::= function-result-list-parens 95 /// | non-function-type 96 /// function-result-list-parens ::= `(` `)` 97 /// | `(` function-result-list-no-parens `)` 98 /// function-result-list-no-parens ::= function-result (`,` function-result)* 99 /// function-result ::= type attribute-dict? 100 /// 101 static ParseResult parseFunctionResultList( 102 OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes, 103 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) { 104 if (failed(parser.parseOptionalLParen())) { 105 // We already know that there is no `(`, so parse a type. 106 // Because there is no `(`, it cannot be a function type. 107 Type ty; 108 if (parser.parseType(ty)) 109 return failure(); 110 resultTypes.push_back(ty); 111 resultAttrs.emplace_back(); 112 return success(); 113 } 114 115 // Special case for an empty set of parens. 116 if (succeeded(parser.parseOptionalRParen())) 117 return success(); 118 119 // Parse individual function results. 120 do { 121 resultTypes.emplace_back(); 122 resultAttrs.emplace_back(); 123 if (parser.parseType(resultTypes.back()) || 124 parser.parseOptionalAttrDict(resultAttrs.back())) { 125 return failure(); 126 } 127 } while (succeeded(parser.parseOptionalComma())); 128 return parser.parseRParen(); 129 } 130 131 /// Parses a function signature using `parser`. The `allowVariadic` argument 132 /// indicates whether functions with variadic arguments are supported. The 133 /// trailing arguments are populated by this function with names, types and 134 /// attributes of the arguments and those of the results. 135 ParseResult mlir::impl::parseFunctionSignature( 136 OpAsmParser &parser, bool allowVariadic, 137 SmallVectorImpl<OpAsmParser::OperandType> &argNames, 138 SmallVectorImpl<Type> &argTypes, 139 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic, 140 SmallVectorImpl<Type> &resultTypes, 141 SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) { 142 if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, 143 isVariadic)) 144 return failure(); 145 if (succeeded(parser.parseOptionalArrow())) 146 return parseFunctionResultList(parser, resultTypes, resultAttrs); 147 return success(); 148 } 149 150 void mlir::impl::addArgAndResultAttrs( 151 Builder &builder, OperationState &result, 152 ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs, 153 ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) { 154 // Add the attributes to the function arguments. 155 SmallString<8> attrNameBuf; 156 for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) 157 if (!argAttrs[i].empty()) 158 result.addAttribute(getArgAttrName(i, attrNameBuf), 159 builder.getDictionaryAttr(argAttrs[i])); 160 161 // Add the attributes to the function results. 162 for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) 163 if (!resultAttrs[i].empty()) 164 result.addAttribute(getResultAttrName(i, attrNameBuf), 165 builder.getDictionaryAttr(resultAttrs[i])); 166 } 167 168 /// Parser implementation for function-like operations. Uses `funcTypeBuilder` 169 /// to construct the custom function type given lists of input and output types. 170 ParseResult 171 mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, 172 bool allowVariadic, 173 mlir::impl::FuncTypeBuilder funcTypeBuilder) { 174 SmallVector<OpAsmParser::OperandType, 4> entryArgs; 175 SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs; 176 SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs; 177 SmallVector<Type, 4> argTypes; 178 SmallVector<Type, 4> resultTypes; 179 auto &builder = parser.getBuilder(); 180 181 // Parse the name as a symbol. 182 StringAttr nameAttr; 183 if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), 184 result.attributes)) 185 return failure(); 186 187 // Parse the function signature. 188 auto signatureLocation = parser.getCurrentLocation(); 189 bool isVariadic = false; 190 if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes, 191 argAttrs, isVariadic, resultTypes, resultAttrs)) 192 return failure(); 193 194 std::string errorMessage; 195 if (auto type = funcTypeBuilder(builder, argTypes, resultTypes, 196 impl::VariadicFlag(isVariadic), errorMessage)) 197 result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); 198 else 199 return parser.emitError(signatureLocation) 200 << "failed to construct function type" 201 << (errorMessage.empty() ? "" : ": ") << errorMessage; 202 203 // If function attributes are present, parse them. 204 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 205 return failure(); 206 207 // Add the attributes to the function arguments. 208 assert(argAttrs.size() == argTypes.size()); 209 assert(resultAttrs.size() == resultTypes.size()); 210 addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); 211 212 // Parse the optional function body. 213 auto *body = result.addRegion(); 214 return parser.parseOptionalRegion( 215 *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes); 216 } 217 218 // Print a function result list. 219 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, 220 ArrayRef<ArrayRef<NamedAttribute>> attrs) { 221 assert(!types.empty() && "Should not be called for empty result list."); 222 auto &os = p.getStream(); 223 bool needsParens = 224 types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty(); 225 if (needsParens) 226 os << '('; 227 interleaveComma(llvm::zip(types, attrs), os, 228 [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) { 229 p.printType(std::get<0>(t)); 230 p.printOptionalAttrDict(std::get<1>(t)); 231 }); 232 if (needsParens) 233 os << ')'; 234 } 235 236 /// Print the signature of the function-like operation `op`. Assumes `op` has 237 /// the FunctionLike trait and passed the verification. 238 void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, 239 ArrayRef<Type> argTypes, 240 bool isVariadic, 241 ArrayRef<Type> resultTypes) { 242 Region &body = op->getRegion(0); 243 bool isExternal = body.empty(); 244 245 p << '('; 246 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { 247 if (i > 0) 248 p << ", "; 249 250 if (!isExternal) { 251 p.printOperand(body.front().getArgument(i)); 252 p << ": "; 253 } 254 255 p.printType(argTypes[i]); 256 p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i)); 257 } 258 259 if (isVariadic) { 260 if (!argTypes.empty()) 261 p << ", "; 262 p << "..."; 263 } 264 265 p << ')'; 266 267 if (!resultTypes.empty()) { 268 p.getStream() << " -> "; 269 SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs; 270 for (int i = 0, e = resultTypes.size(); i < e; ++i) 271 resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i)); 272 printFunctionResultList(p, resultTypes, resultAttrs); 273 } 274 } 275 276 /// Prints the list of function prefixed with the "attributes" keyword. The 277 /// attributes with names listed in "elided" as well as those used by the 278 /// function-like operation internally are not printed. Nothing is printed 279 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and 280 /// passed the verification. 281 void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, 282 unsigned numInputs, 283 unsigned numResults, 284 ArrayRef<StringRef> elided) { 285 // Print out function attributes, if present. 286 SmallVector<StringRef, 2> ignoredAttrs = { 287 ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; 288 ignoredAttrs.append(elided.begin(), elided.end()); 289 290 SmallString<8> attrNameBuf; 291 292 // Ignore any argument attributes. 293 std::vector<SmallString<8>> argAttrStorage; 294 for (unsigned i = 0; i != numInputs; ++i) 295 if (op->getAttr(getArgAttrName(i, attrNameBuf))) 296 argAttrStorage.emplace_back(attrNameBuf); 297 ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); 298 299 // Ignore any result attributes. 300 std::vector<SmallString<8>> resultAttrStorage; 301 for (unsigned i = 0; i != numResults; ++i) 302 if (op->getAttr(getResultAttrName(i, attrNameBuf))) 303 resultAttrStorage.emplace_back(attrNameBuf); 304 ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); 305 306 p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); 307 } 308 309 /// Printer implementation for function-like operations. Accepts lists of 310 /// argument and result types to use while printing. 311 void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, 312 ArrayRef<Type> argTypes, bool isVariadic, 313 ArrayRef<Type> resultTypes) { 314 // Print the operation and the function name. 315 auto funcName = 316 op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) 317 .getValue(); 318 p << op->getName() << ' '; 319 p.printSymbolName(funcName); 320 321 printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); 322 printFunctionAttributes(p, op, argTypes.size(), resultTypes.size()); 323 324 // Print the body if this is not an external function. 325 Region &body = op->getRegion(0); 326 if (!body.empty()) 327 p.printRegion(body, /*printEntryBlockArgs=*/false, 328 /*printBlockTerminators=*/true); 329 } 330