1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionImplementation.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionInterfaces.h"
12 #include "mlir/IR/SymbolTable.h"
13 
14 using namespace mlir;
15 
16 ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
17     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
18     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
19     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
20     bool &isVariadic) {
21   if (parser.parseLParen())
22     return failure();
23 
24   // The argument list either has to consistently have ssa-id's followed by
25   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
26   // sometimes not.
27   auto parseArgument = [&]() -> ParseResult {
28     SMLoc loc = parser.getCurrentLocation();
29 
30     // Parse argument name if present.
31     OpAsmParser::UnresolvedOperand argument;
32     Type argumentType;
33     if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
34         !argument.name.empty()) {
35       // Reject this if the preceding argument was missing a name.
36       if (argNames.empty() && !argTypes.empty())
37         return parser.emitError(loc, "expected type instead of SSA identifier");
38 
39       // Parse required type.
40       if (parser.parseColonType(argumentType))
41         return failure();
42     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
43       isVariadic = true;
44       return success();
45     } else if (!argNames.empty()) {
46       // Reject this if the preceding argument had a name.
47       return parser.emitError(loc, "expected SSA identifier");
48     } else if (parser.parseType(argumentType)) {
49       return failure();
50     }
51 
52     // Add the argument type.
53     argTypes.push_back(argumentType);
54 
55     // Parse any argument attributes and source location information.
56     NamedAttrList attrs;
57     if (parser.parseOptionalAttrDict(attrs) ||
58         parser.parseOptionalLocationSpecifier(argument.sourceLoc))
59       return failure();
60 
61     if (!allowAttributes && !attrs.empty())
62       return parser.emitError(loc, "expected arguments without attributes");
63     argAttrs.push_back(attrs);
64 
65     // If we had an argument name, then remember the parsed argument.
66     if (!argument.name.empty())
67       argNames.push_back(argument);
68     return success();
69   };
70 
71   // Parse the function arguments.
72   isVariadic = false;
73   if (failed(parser.parseOptionalRParen())) {
74     do {
75       unsigned numTypedArguments = argTypes.size();
76       if (parseArgument())
77         return failure();
78 
79       SMLoc loc = parser.getCurrentLocation();
80       if (argTypes.size() == numTypedArguments &&
81           succeeded(parser.parseOptionalComma()))
82         return parser.emitError(
83             loc, "variadic arguments must be in the end of the argument list");
84     } while (succeeded(parser.parseOptionalComma()));
85     parser.parseRParen();
86   }
87 
88   return success();
89 }
90 
91 /// Parse a function result list.
92 ///
93 ///   function-result-list ::= function-result-list-parens
94 ///                          | non-function-type
95 ///   function-result-list-parens ::= `(` `)`
96 ///                                 | `(` function-result-list-no-parens `)`
97 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
98 ///   function-result ::= type attribute-dict?
99 ///
100 static ParseResult
101 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
102                         SmallVectorImpl<NamedAttrList> &resultAttrs) {
103   if (failed(parser.parseOptionalLParen())) {
104     // We already know that there is no `(`, so parse a type.
105     // Because there is no `(`, it cannot be a function type.
106     Type ty;
107     if (parser.parseType(ty))
108       return failure();
109     resultTypes.push_back(ty);
110     resultAttrs.emplace_back();
111     return success();
112   }
113 
114   // Special case for an empty set of parens.
115   if (succeeded(parser.parseOptionalRParen()))
116     return success();
117 
118   // Parse individual function results.
119   do {
120     resultTypes.emplace_back();
121     resultAttrs.emplace_back();
122     if (parser.parseType(resultTypes.back()) ||
123         parser.parseOptionalAttrDict(resultAttrs.back())) {
124       return failure();
125     }
126   } while (succeeded(parser.parseOptionalComma()));
127   return parser.parseRParen();
128 }
129 
130 ParseResult mlir::function_interface_impl::parseFunctionSignature(
131     OpAsmParser &parser, bool allowVariadic,
132     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
133     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
134     bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
135     SmallVectorImpl<NamedAttrList> &resultAttrs) {
136   bool allowArgAttrs = true;
137   if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
138                                 argTypes, argAttrs, isVariadic))
139     return failure();
140   if (succeeded(parser.parseOptionalArrow()))
141     return parseFunctionResultList(parser, resultTypes, resultAttrs);
142   return success();
143 }
144 
145 /// Implementation of `addArgAndResultAttrs` that is attribute list type
146 /// agnostic.
147 template <typename AttrListT, typename AttrArrayBuildFnT>
148 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
149                                      ArrayRef<AttrListT> argAttrs,
150                                      ArrayRef<AttrListT> resultAttrs,
151                                      AttrArrayBuildFnT &&buildAttrArrayFn) {
152   auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
153 
154   // Add the attributes to the function arguments.
155   if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
156     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
157     result.addAttribute(function_interface_impl::getArgDictAttrName(),
158                         attrDicts);
159   }
160   // Add the attributes to the function results.
161   if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
162     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
163     result.addAttribute(function_interface_impl::getResultDictAttrName(),
164                         attrDicts);
165   }
166 }
167 
168 void mlir::function_interface_impl::addArgAndResultAttrs(
169     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
170     ArrayRef<DictionaryAttr> resultAttrs) {
171   auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
172     return ArrayRef<Attribute>(attrs.data(), attrs.size());
173   };
174   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
175 }
176 void mlir::function_interface_impl::addArgAndResultAttrs(
177     Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
178     ArrayRef<NamedAttrList> resultAttrs) {
179   MLIRContext *context = builder.getContext();
180   auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
181     return llvm::to_vector<8>(
182         llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
183           return attrList.getDictionary(context);
184         }));
185   };
186   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
187 }
188 
189 ParseResult mlir::function_interface_impl::parseFunctionOp(
190     OpAsmParser &parser, OperationState &result, bool allowVariadic,
191     FuncTypeBuilder funcTypeBuilder) {
192   SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
193   SmallVector<NamedAttrList> argAttrs;
194   SmallVector<NamedAttrList> resultAttrs;
195   SmallVector<Type> argTypes;
196   SmallVector<Type> resultTypes;
197   auto &builder = parser.getBuilder();
198 
199   // Parse visibility.
200   impl::parseOptionalVisibilityKeyword(parser, result.attributes);
201 
202   // Parse the name as a symbol.
203   StringAttr nameAttr;
204   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
205                              result.attributes))
206     return failure();
207 
208   // Parse the function signature.
209   SMLoc signatureLocation = parser.getCurrentLocation();
210   bool isVariadic = false;
211   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
212                              argAttrs, isVariadic, resultTypes, resultAttrs))
213     return failure();
214 
215   std::string errorMessage;
216   Type type = funcTypeBuilder(builder, argTypes, resultTypes,
217                               VariadicFlag(isVariadic), errorMessage);
218   if (!type) {
219     return parser.emitError(signatureLocation)
220            << "failed to construct function type"
221            << (errorMessage.empty() ? "" : ": ") << errorMessage;
222   }
223   result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
224 
225   // If function attributes are present, parse them.
226   NamedAttrList parsedAttributes;
227   SMLoc attributeDictLocation = parser.getCurrentLocation();
228   if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
229     return failure();
230 
231   // Disallow attributes that are inferred from elsewhere in the attribute
232   // dictionary.
233   for (StringRef disallowed :
234        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
235         getTypeAttrName()}) {
236     if (parsedAttributes.get(disallowed))
237       return parser.emitError(attributeDictLocation, "'")
238              << disallowed
239              << "' is an inferred attribute and should not be specified in the "
240                 "explicit attribute dictionary";
241   }
242   result.attributes.append(parsedAttributes);
243 
244   // Add the attributes to the function arguments.
245   assert(argAttrs.size() == argTypes.size());
246   assert(resultAttrs.size() == resultTypes.size());
247   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
248 
249   // Parse the optional function body. The printer will not print the body if
250   // its empty, so disallow parsing of empty body in the parser.
251   auto *body = result.addRegion();
252   SMLoc loc = parser.getCurrentLocation();
253   OptionalParseResult parseResult = parser.parseOptionalRegion(
254       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
255       /*enableNameShadowing=*/false);
256   if (parseResult.hasValue()) {
257     if (failed(*parseResult))
258       return failure();
259     // Function body was parsed, make sure its not empty.
260     if (body->empty())
261       return parser.emitError(loc, "expected non-empty function body");
262   }
263   return success();
264 }
265 
266 /// Print a function result list. The provided `attrs` must either be null, or
267 /// contain a set of DictionaryAttrs of the same arity as `types`.
268 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
269                                     ArrayAttr attrs) {
270   assert(!types.empty() && "Should not be called for empty result list.");
271   assert((!attrs || attrs.size() == types.size()) &&
272          "Invalid number of attributes.");
273 
274   auto &os = p.getStream();
275   bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
276                      (attrs && !attrs[0].cast<DictionaryAttr>().empty());
277   if (needsParens)
278     os << '(';
279   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
280     p.printType(types[i]);
281     if (attrs)
282       p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
283   });
284   if (needsParens)
285     os << ')';
286 }
287 
288 void mlir::function_interface_impl::printFunctionSignature(
289     OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
290     ArrayRef<Type> resultTypes) {
291   Region &body = op->getRegion(0);
292   bool isExternal = body.empty();
293 
294   p << '(';
295   ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
296   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
297     if (i > 0)
298       p << ", ";
299 
300     if (!isExternal) {
301       ArrayRef<NamedAttribute> attrs;
302       if (argAttrs)
303         attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
304       p.printRegionArgument(body.getArgument(i), attrs);
305     } else {
306       p.printType(argTypes[i]);
307       if (argAttrs)
308         p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
309     }
310   }
311 
312   if (isVariadic) {
313     if (!argTypes.empty())
314       p << ", ";
315     p << "...";
316   }
317 
318   p << ')';
319 
320   if (!resultTypes.empty()) {
321     p.getStream() << " -> ";
322     auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
323     printFunctionResultList(p, resultTypes, resultAttrs);
324   }
325 }
326 
327 void mlir::function_interface_impl::printFunctionAttributes(
328     OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
329     ArrayRef<StringRef> elided) {
330   // Print out function attributes, if present.
331   SmallVector<StringRef, 2> ignoredAttrs = {
332       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
333       getArgDictAttrName(), getResultDictAttrName()};
334   ignoredAttrs.append(elided.begin(), elided.end());
335 
336   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
337 }
338 
339 void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p,
340                                                     FunctionOpInterface op,
341                                                     bool isVariadic) {
342   // Print the operation and the function name.
343   auto funcName =
344       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
345           .getValue();
346   p << ' ';
347 
348   StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
349   if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
350     p << visibility.getValue() << ' ';
351   p.printSymbolName(funcName);
352 
353   ArrayRef<Type> argTypes = op.getArgumentTypes();
354   ArrayRef<Type> resultTypes = op.getResultTypes();
355   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
356   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
357                           {visibilityAttrName});
358   // Print the body if this is not an external function.
359   Region &body = op->getRegion(0);
360   if (!body.empty()) {
361     p << ' ';
362     p.printRegion(body, /*printEntryBlockArgs=*/false,
363                   /*printBlockTerminators=*/true);
364   }
365 }
366