1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/FunctionImplementation.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/FunctionSupport.h"
21 #include "mlir/IR/SymbolTable.h"
22 
23 using namespace mlir;
24 
25 static ParseResult
26 parseArgumentList(OpAsmParser &parser, bool allowVariadic,
27                   SmallVectorImpl<Type> &argTypes,
28                   SmallVectorImpl<OpAsmParser::OperandType> &argNames,
29                   SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs,
30                   bool &isVariadic) {
31   if (parser.parseLParen())
32     return failure();
33 
34   // The argument list either has to consistently have ssa-id's followed by
35   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
36   // sometimes not.
37   auto parseArgument = [&]() -> ParseResult {
38     llvm::SMLoc loc = parser.getCurrentLocation();
39 
40     // Parse argument name if present.
41     OpAsmParser::OperandType argument;
42     Type argumentType;
43     if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
44         !argument.name.empty()) {
45       // Reject this if the preceding argument was missing a name.
46       if (argNames.empty() && !argTypes.empty())
47         return parser.emitError(loc, "expected type instead of SSA identifier");
48       argNames.push_back(argument);
49 
50       if (parser.parseColonType(argumentType))
51         return failure();
52     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
53       isVariadic = true;
54       return success();
55     } else if (!argNames.empty()) {
56       // Reject this if the preceding argument had a name.
57       return parser.emitError(loc, "expected SSA identifier");
58     } else if (parser.parseType(argumentType)) {
59       return failure();
60     }
61 
62     // Add the argument type.
63     argTypes.push_back(argumentType);
64 
65     // Parse any argument attributes.
66     SmallVector<NamedAttribute, 2> attrs;
67     if (parser.parseOptionalAttrDict(attrs))
68       return failure();
69     argAttrs.push_back(attrs);
70     return success();
71   };
72 
73   // Parse the function arguments.
74   isVariadic = false;
75   if (failed(parser.parseOptionalRParen())) {
76     do {
77       unsigned numTypedArguments = argTypes.size();
78       if (parseArgument())
79         return failure();
80 
81       llvm::SMLoc loc = parser.getCurrentLocation();
82       if (argTypes.size() == numTypedArguments &&
83           succeeded(parser.parseOptionalComma()))
84         return parser.emitError(
85             loc, "variadic arguments must be in the end of the argument list");
86     } while (succeeded(parser.parseOptionalComma()));
87     parser.parseRParen();
88   }
89 
90   return success();
91 }
92 
93 /// Parse a function result list.
94 ///
95 ///   function-result-list ::= function-result-list-parens
96 ///                          | non-function-type
97 ///   function-result-list-parens ::= `(` `)`
98 ///                                 | `(` function-result-list-no-parens `)`
99 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
100 ///   function-result ::= type attribute-dict?
101 ///
102 static ParseResult parseFunctionResultList(
103     OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
104     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
105   if (failed(parser.parseOptionalLParen())) {
106     // We already know that there is no `(`, so parse a type.
107     // Because there is no `(`, it cannot be a function type.
108     Type ty;
109     if (parser.parseType(ty))
110       return failure();
111     resultTypes.push_back(ty);
112     resultAttrs.emplace_back();
113     return success();
114   }
115 
116   // Special case for an empty set of parens.
117   if (succeeded(parser.parseOptionalRParen()))
118     return success();
119 
120   // Parse individual function results.
121   do {
122     resultTypes.emplace_back();
123     resultAttrs.emplace_back();
124     if (parser.parseType(resultTypes.back()) ||
125         parser.parseOptionalAttrDict(resultAttrs.back())) {
126       return failure();
127     }
128   } while (succeeded(parser.parseOptionalComma()));
129   return parser.parseRParen();
130 }
131 
132 /// Parses a function signature using `parser`. The `allowVariadic` argument
133 /// indicates whether functions with variadic arguments are supported. The
134 /// trailing arguments are populated by this function with names, types and
135 /// attributes of the arguments and those of the results.
136 ParseResult mlir::impl::parseFunctionSignature(
137     OpAsmParser &parser, bool allowVariadic,
138     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
139     SmallVectorImpl<Type> &argTypes,
140     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
141     SmallVectorImpl<Type> &resultTypes,
142     SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs) {
143   if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs,
144                         isVariadic))
145     return failure();
146   if (succeeded(parser.parseOptionalArrow()))
147     return parseFunctionResultList(parser, resultTypes, resultAttrs);
148   return success();
149 }
150 
151 void mlir::impl::addArgAndResultAttrs(
152     Builder &builder, OperationState &result,
153     ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
154     ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) {
155   // Add the attributes to the function arguments.
156   SmallString<8> attrNameBuf;
157   for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
158     if (!argAttrs[i].empty())
159       result.addAttribute(getArgAttrName(i, attrNameBuf),
160                           builder.getDictionaryAttr(argAttrs[i]));
161 
162   // Add the attributes to the function results.
163   for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
164     if (!resultAttrs[i].empty())
165       result.addAttribute(getResultAttrName(i, attrNameBuf),
166                           builder.getDictionaryAttr(resultAttrs[i]));
167 }
168 
169 /// Parser implementation for function-like operations.  Uses `funcTypeBuilder`
170 /// to construct the custom function type given lists of input and output types.
171 ParseResult
172 mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
173                                 bool allowVariadic,
174                                 mlir::impl::FuncTypeBuilder funcTypeBuilder) {
175   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
176   SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
177   SmallVector<SmallVector<NamedAttribute, 2>, 4> resultAttrs;
178   SmallVector<Type, 4> argTypes;
179   SmallVector<Type, 4> resultTypes;
180   auto &builder = parser.getBuilder();
181 
182   // Parse the name as a symbol.
183   StringAttr nameAttr;
184   if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
185                              result.attributes))
186     return failure();
187 
188   // Parse the function signature.
189   auto signatureLocation = parser.getCurrentLocation();
190   bool isVariadic = false;
191   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
192                              argAttrs, isVariadic, resultTypes, resultAttrs))
193     return failure();
194 
195   std::string errorMessage;
196   if (auto type = funcTypeBuilder(builder, argTypes, resultTypes,
197                                   impl::VariadicFlag(isVariadic), errorMessage))
198     result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
199   else
200     return parser.emitError(signatureLocation)
201            << "failed to construct function type"
202            << (errorMessage.empty() ? "" : ": ") << errorMessage;
203 
204   // If function attributes are present, parse them.
205   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
206     return failure();
207 
208   // Add the attributes to the function arguments.
209   assert(argAttrs.size() == argTypes.size());
210   assert(resultAttrs.size() == resultTypes.size());
211   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
212 
213   // Parse the optional function body.
214   auto *body = result.addRegion();
215   return parser.parseOptionalRegion(
216       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
217 }
218 
219 // Print a function result list.
220 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
221                                     ArrayRef<ArrayRef<NamedAttribute>> attrs) {
222   assert(!types.empty() && "Should not be called for empty result list.");
223   auto &os = p.getStream();
224   bool needsParens =
225       types.size() > 1 || types[0].isa<FunctionType>() || !attrs[0].empty();
226   if (needsParens)
227     os << '(';
228   interleaveComma(llvm::zip(types, attrs), os,
229                   [&](const std::tuple<Type, ArrayRef<NamedAttribute>> &t) {
230                     p.printType(std::get<0>(t));
231                     p.printOptionalAttrDict(std::get<1>(t));
232                   });
233   if (needsParens)
234     os << ')';
235 }
236 
237 /// Print the signature of the function-like operation `op`.  Assumes `op` has
238 /// the FunctionLike trait and passed the verification.
239 void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
240                                         ArrayRef<Type> argTypes,
241                                         bool isVariadic,
242                                         ArrayRef<Type> resultTypes) {
243   Region &body = op->getRegion(0);
244   bool isExternal = body.empty();
245 
246   p << '(';
247   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
248     if (i > 0)
249       p << ", ";
250 
251     if (!isExternal) {
252       p.printOperand(body.front().getArgument(i));
253       p << ": ";
254     }
255 
256     p.printType(argTypes[i]);
257     p.printOptionalAttrDict(::mlir::impl::getArgAttrs(op, i));
258   }
259 
260   if (isVariadic) {
261     if (!argTypes.empty())
262       p << ", ";
263     p << "...";
264   }
265 
266   p << ')';
267 
268   if (!resultTypes.empty()) {
269     p.getStream() << " -> ";
270     SmallVector<ArrayRef<NamedAttribute>, 4> resultAttrs;
271     for (int i = 0, e = resultTypes.size(); i < e; ++i)
272       resultAttrs.push_back(::mlir::impl::getResultAttrs(op, i));
273     printFunctionResultList(p, resultTypes, resultAttrs);
274   }
275 }
276 
277 /// Prints the list of function prefixed with the "attributes" keyword. The
278 /// attributes with names listed in "elided" as well as those used by the
279 /// function-like operation internally are not printed. Nothing is printed
280 /// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
281 /// passed the verification.
282 void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
283                                          unsigned numInputs,
284                                          unsigned numResults,
285                                          ArrayRef<StringRef> elided) {
286   // Print out function attributes, if present.
287   SmallVector<StringRef, 2> ignoredAttrs = {
288       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
289   ignoredAttrs.append(elided.begin(), elided.end());
290 
291   SmallString<8> attrNameBuf;
292 
293   // Ignore any argument attributes.
294   std::vector<SmallString<8>> argAttrStorage;
295   for (unsigned i = 0; i != numInputs; ++i)
296     if (op->getAttr(getArgAttrName(i, attrNameBuf)))
297       argAttrStorage.emplace_back(attrNameBuf);
298   ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
299 
300   // Ignore any result attributes.
301   std::vector<SmallString<8>> resultAttrStorage;
302   for (unsigned i = 0; i != numResults; ++i)
303     if (op->getAttr(getResultAttrName(i, attrNameBuf)))
304       resultAttrStorage.emplace_back(attrNameBuf);
305   ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
306 
307   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
308 }
309 
310 /// Printer implementation for function-like operations.  Accepts lists of
311 /// argument and result types to use while printing.
312 void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
313                                      ArrayRef<Type> argTypes, bool isVariadic,
314                                      ArrayRef<Type> resultTypes) {
315   // Print the operation and the function name.
316   auto funcName =
317       op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
318           .getValue();
319   p << op->getName() << ' ';
320   p.printSymbolName(funcName);
321 
322   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
323   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size());
324 
325   // Print the body if this is not an external function.
326   Region &body = op->getRegion(0);
327   if (!body.empty())
328     p.printRegion(body, /*printEntryBlockArgs=*/false,
329                   /*printBlockTerminators=*/true);
330 }
331