1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/FunctionImplementation.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/FunctionSupport.h"
21 #include "mlir/IR/SymbolTable.h"
22 
23 using namespace mlir;
24 
25 static ParseResult
26 parseArgumentList(OpAsmParser &parser, bool allowVariadic,
27                   SmallVectorImpl<Type> &argTypes,
28                   SmallVectorImpl<OpAsmParser::OperandType> &argNames,
29                   SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs,
30                   bool &isVariadic) {
31   if (parser.parseLParen())
32     return failure();
33 
34   // The argument list either has to consistently have ssa-id's followed by
35   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
36   // sometimes not.
37   auto parseArgument = [&]() -> ParseResult {
38     llvm::SMLoc loc = parser.getCurrentLocation();
39 
40     // Parse argument name if present.
41     OpAsmParser::OperandType argument;
42     Type argumentType;
43     if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
44         !argument.name.empty()) {
45       // Reject this if the preceding argument was missing a name.
46       if (argNames.empty() && !argTypes.empty())
47         return parser.emitError(loc, "expected type instead of SSA identifier");
48       argNames.push_back(argument);
49 
50       if (parser.parseColonType(argumentType))
51         return failure();
52     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
53       isVariadic = true;
54       return success();
55     } else if (!argNames.empty()) {
56       // Reject this if the preceding argument had a name.
57       return parser.emitError(loc, "expected SSA identifier");
58     } else if (parser.parseType(argumentType)) {
59       return failure();
60     }
61 
62     // Add the argument type.
63     argTypes.push_back(argumentType);
64 
65     // Parse any argument attributes.
66     SmallVector<NamedAttribute, 2> attrs;
67     if (parser.parseOptionalAttrDict(attrs))
68       return failure();
69     argAttrs.push_back(attrs);
70     return success();
71   };
72 
73   // Parse the function arguments.
74   if (parser.parseOptionalRParen()) {
75     do {
76       unsigned numTypedArguments = argTypes.size();
77       if (parseArgument())
78         return failure();
79 
80       llvm::SMLoc loc = parser.getCurrentLocation();
81       if (argTypes.size() == numTypedArguments &&
82           succeeded(parser.parseOptionalComma()))
83         return parser.emitError(
84             loc, "variadic arguments must be in the end of the argument list");
85     } while (succeeded(parser.parseOptionalComma()));
86     parser.parseRParen();
87   }
88 
89   return success();
90 }
91 
92 /// Parse a function result list.
93 ///
94 ///   function-result-list ::= function-result-list-parens
95 ///                          | non-function-type
96 ///   function-result-list-parens ::= `(` `)`
97 ///                                 | `(` function-result-list-no-parens `)`
98 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
99 ///   function-result ::= type attribute-dict?
100 ///
101 static ParseResult parseFunctionResultList(
102     OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
103     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
104   if (failed(parser.parseOptionalLParen())) {
105     // We already know that there is no `(`, so parse a type.
106     // Because there is no `(`, it cannot be a function type.
107     Type ty;
108     if (parser.parseType(ty))
109       return failure();
110     resultTypes.push_back(ty);
111     resultAttrs.emplace_back();
112     return success();
113   }
114 
115   // Special case for an empty set of parens.
116   if (succeeded(parser.parseOptionalRParen()))
117     return success();
118 
119   // Parse individual function results.
120   do {
121     resultTypes.emplace_back();
122     resultAttrs.emplace_back();
123     if (parser.parseType(resultTypes.back()) ||
124         parser.parseOptionalAttrDict(resultAttrs.back())) {
125       return failure();
126     }
127   } while (succeeded(parser.parseOptionalComma()));
128   return parser.parseRParen();
129 }
130 
131 /// Parses a function signature using `parser`. The `allowVariadic` argument
132 /// indicates whether functions with variadic arguments are supported. The
133 /// trailing arguments are populated by this function with names, types and
134 /// attributes of the arguments and those of the results.
135 ParseResult mlir::impl::parseFunctionSignature(
136     OpAsmParser &parser, bool allowVariadic,
137     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
138     SmallVectorImpl<Type> &argTypes,
139     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
140     SmallVectorImpl<Type> &resultTypes,
141     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
142   if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
143                         isVariadic))
144     return failure();
145   if (succeeded(parser.parseOptionalArrow()))
146     return parseFunctionResultList(parser, resultTypes, resultAttrs);
147   return success();
148 }
149 
150 void mlir::impl::addArgAndResultAttrs(
151     Builder &builder, OperationState &result,
152     ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
153     ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) {
154   // Add the attributes to the function arguments.
155   SmallString<8> attrNameBuf;
156   for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
157     if (!argAttrs[i].empty())
158       result.addAttribute(getArgAttrName(i, attrNameBuf),
159                           builder.getDictionaryAttr(argAttrs[i]));
160 
161   // Add the attributes to the function results.
162   for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
163     if (!resultAttrs[i].empty())
164       result.addAttribute(getResultAttrName(i, attrNameBuf),
165                           builder.getDictionaryAttr(resultAttrs[i]));
166 }
167 
168 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
169 /// to construct the custom function type given lists of input and output types.
170 ParseResult
171 mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
172                                 bool allowVariadic,
173                                 mlir::impl::FuncTypeBuilder funcTypeBuilder) {
174   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
175   SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
176   SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
177   SmallVector<Type, 4> argTypes;
178   SmallVector<Type, 4> resultTypes;
179   auto &builder = parser.getBuilder();
180 
181   // Parse the name as a symbol.
182   StringAttr nameAttr;
183   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
184                              result.attributes))
185     return failure();
186 
187   // Parse the function signature.
188   auto signatureLocation = parser.getCurrentLocation();
189   bool isVariadic = false;
190   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
191                              argAttrs, isVariadic, resultTypes, resultAttrs))
192     return failure();
193 
194   std::string errorMessage;
195   if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
196                                   impl::VariadicFlag(isVariadic), errorMessage))
197     result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
198   else
199     return parser.emitError(signatureLocation)
200            << "failed to construct function type"
201            << (errorMessage.empty() ? "" : ": ") << errorMessage;
202 
203   // If function attributes are present, parse them.
204   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
205     return failure();
206 
207   // Add the attributes to the function arguments.
208   assert(argAttrs.size() == argTypes.size());
209   assert(resultAttrs.size() == resultTypes.size());
210   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
211 
212   // Parse the optional function body.
213   auto *body = result.addRegion();
214   return parser.parseOptionalRegion(
215       *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes);
216 }
217 
218 // Print a function result list.
219 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
220                                     ArrayRef<ArrayRef<NamedAttribute>> attrs) {
221   assert(!types.empty() && "Should not be called for empty result list.");
222   auto &os = p.getStream();
223   bool needsParens =
224       types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
225   if (needsParens)
226     os << '(';
227   interleaveComma(llvm::zip(types, attrs), os,
228                   [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
229                     p.printType(std::get<0>(t));
230                     p.printOptionalAttrDict(std::get<1>(t));
231                   });
232   if (needsParens)
233     os << ')';
234 }
235 
236 /// Print the signature of the function-like operation `op`.  Assumes `op` has
237 /// the FunctionLike trait and passed the verification.
238 void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
239                                         ArrayRef<Type> argTypes,
240                                         bool isVariadic,
241                                         ArrayRef<Type> resultTypes) {
242   Region &body = op->getRegion(0);
243   bool isExternal = body.empty();
244 
245   p << '(';
246   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
247     if (i > 0)
248       p << ", ";
249 
250     if (!isExternal) {
251       p.printOperand(body.front().getArgument(i));
252       p << ": ";
253     }
254 
255     p.printType(argTypes[i]);
256     p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
257   }
258 
259   if (isVariadic) {
260     if (!argTypes.empty())
261       p << ", ";
262     p << "...";
263   }
264 
265   p << ')';
266 
267   if (!resultTypes.empty()) {
268     p.getStream() << " -> ";
269     SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
270     for (int i = 0, e = resultTypes.size(); i < e; ++i)
271       resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
272     printFunctionResultList(p, resultTypes, resultAttrs);
273   }
274 }
275 
276 /// Prints the list of function prefixed with the "attributes" keyword. The
277 /// attributes with names listed in "elided" as well as those used by the
278 /// function-like operation internally are not printed. Nothing is printed
279 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
280 /// passed the verification.
281 void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
282                                          unsigned numInputs,
283                                          unsigned numResults,
284                                          ArrayRef<StringRef> elided) {
285   // Print out function attributes, if present.
286   SmallVector<StringRef, 2> ignoredAttrs = {
287       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
288   ignoredAttrs.append(elided.begin(), elided.end());
289 
290   SmallString<8> attrNameBuf;
291 
292   // Ignore any argument attributes.
293   std::vector<SmallString<8>> argAttrStorage;
294   for (unsigned i = 0; i != numInputs; ++i)
295     if (op->getAttr(getArgAttrName(i, attrNameBuf)))
296       argAttrStorage.emplace_back(attrNameBuf);
297   ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
298 
299   // Ignore any result attributes.
300   std::vector<SmallString<8>> resultAttrStorage;
301   for (unsigned i = 0; i != numResults; ++i)
302     if (op->getAttr(getResultAttrName(i, attrNameBuf)))
303       resultAttrStorage.emplace_back(attrNameBuf);
304   ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
305 
306   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
307 }
308 
309 /// Printer implementation for function-like operations.  Accepts lists of
310 /// argument and result types to use while printing.
311 void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
312                                      ArrayRef<Type> argTypes, bool isVariadic,
313                                      ArrayRef<Type> resultTypes) {
314   // Print the operation and the function name.
315   auto funcName =
316       op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
317           .getValue();
318   p << op->getName() << ' ';
319   p.printSymbolName(funcName);
320 
321   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
322   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size());
323 
324   // Print the body if this is not an external function.
325   Region &body = op->getRegion(0);
326   if (!body.empty())
327     p.printRegion(body, /*printEntryBlockArgs=*/false,
328                   /*printBlockTerminators=*/true);
329 }
330