1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/FunctionImplementation.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionInterfaces.h"
12 #include "mlir/IR/SymbolTable.h"
13 
14 using namespace mlir;
15 
16 ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
17     OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
18     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
19     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
20     bool &isVariadic) {
21   if (parser.parseLParen())
22     return failure();
23 
24   // The argument list either has to consistently have ssa-id's followed by
25   // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
26   // sometimes not.
27   auto parseArgument = [&]() -> ParseResult {
28     llvm::SMLoc loc = parser.getCurrentLocation();
29 
30     // Parse argument name if present.
31     OpAsmParser::OperandType argument;
32     Type argumentType;
33     if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
34         !argument.name.empty()) {
35       // Reject this if the preceding argument was missing a name.
36       if (argNames.empty() && !argTypes.empty())
37         return parser.emitError(loc, "expected type instead of SSA identifier");
38       argNames.push_back(argument);
39 
40       if (parser.parseColonType(argumentType))
41         return failure();
42     } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
43       isVariadic = true;
44       return success();
45     } else if (!argNames.empty()) {
46       // Reject this if the preceding argument had a name.
47       return parser.emitError(loc, "expected SSA identifier");
48     } else if (parser.parseType(argumentType)) {
49       return failure();
50     }
51 
52     // Add the argument type.
53     argTypes.push_back(argumentType);
54 
55     // Parse any argument attributes.
56     NamedAttrList attrs;
57     if (parser.parseOptionalAttrDict(attrs))
58       return failure();
59     if (!allowAttributes && !attrs.empty())
60       return parser.emitError(loc, "expected arguments without attributes");
61     argAttrs.push_back(attrs);
62 
63     // Parse a location if specified.  TODO: Don't drop it on the floor.
64     Optional<Location> explicitLoc;
65     if (!argument.name.empty() &&
66         parser.parseOptionalLocationSpecifier(explicitLoc))
67       return failure();
68 
69     return success();
70   };
71 
72   // Parse the function arguments.
73   isVariadic = false;
74   if (failed(parser.parseOptionalRParen())) {
75     do {
76       unsigned numTypedArguments = argTypes.size();
77       if (parseArgument())
78         return failure();
79 
80       llvm::SMLoc loc = parser.getCurrentLocation();
81       if (argTypes.size() == numTypedArguments &&
82           succeeded(parser.parseOptionalComma()))
83         return parser.emitError(
84             loc, "variadic arguments must be in the end of the argument list");
85     } while (succeeded(parser.parseOptionalComma()));
86     parser.parseRParen();
87   }
88 
89   return success();
90 }
91 
92 /// Parse a function result list.
93 ///
94 ///   function-result-list ::= function-result-list-parens
95 ///                          | non-function-type
96 ///   function-result-list-parens ::= `(` `)`
97 ///                                 | `(` function-result-list-no-parens `)`
98 ///   function-result-list-no-parens ::= function-result (`,` function-result)*
99 ///   function-result ::= type attribute-dict?
100 ///
101 static ParseResult
102 parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
103                         SmallVectorImpl<NamedAttrList> &resultAttrs) {
104   if (failed(parser.parseOptionalLParen())) {
105     // We already know that there is no `(`, so parse a type.
106     // Because there is no `(`, it cannot be a function type.
107     Type ty;
108     if (parser.parseType(ty))
109       return failure();
110     resultTypes.push_back(ty);
111     resultAttrs.emplace_back();
112     return success();
113   }
114 
115   // Special case for an empty set of parens.
116   if (succeeded(parser.parseOptionalRParen()))
117     return success();
118 
119   // Parse individual function results.
120   do {
121     resultTypes.emplace_back();
122     resultAttrs.emplace_back();
123     if (parser.parseType(resultTypes.back()) ||
124         parser.parseOptionalAttrDict(resultAttrs.back())) {
125       return failure();
126     }
127   } while (succeeded(parser.parseOptionalComma()));
128   return parser.parseRParen();
129 }
130 
131 ParseResult mlir::function_interface_impl::parseFunctionSignature(
132     OpAsmParser &parser, bool allowVariadic,
133     SmallVectorImpl<OpAsmParser::OperandType> &argNames,
134     SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
135     bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
136     SmallVectorImpl<NamedAttrList> &resultAttrs) {
137   bool allowArgAttrs = true;
138   if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
139                                 argTypes, argAttrs, isVariadic))
140     return failure();
141   if (succeeded(parser.parseOptionalArrow()))
142     return parseFunctionResultList(parser, resultTypes, resultAttrs);
143   return success();
144 }
145 
146 /// Implementation of `addArgAndResultAttrs` that is attribute list type
147 /// agnostic.
148 template <typename AttrListT, typename AttrArrayBuildFnT>
149 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
150                                      ArrayRef<AttrListT> argAttrs,
151                                      ArrayRef<AttrListT> resultAttrs,
152                                      AttrArrayBuildFnT &&buildAttrArrayFn) {
153   auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
154 
155   // Add the attributes to the function arguments.
156   if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
157     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
158     result.addAttribute(function_interface_impl::getArgDictAttrName(),
159                         attrDicts);
160   }
161   // Add the attributes to the function results.
162   if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
163     ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
164     result.addAttribute(function_interface_impl::getResultDictAttrName(),
165                         attrDicts);
166   }
167 }
168 
169 void mlir::function_interface_impl::addArgAndResultAttrs(
170     Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
171     ArrayRef<DictionaryAttr> resultAttrs) {
172   auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
173     return ArrayRef<Attribute>(attrs.data(), attrs.size());
174   };
175   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
176 }
177 void mlir::function_interface_impl::addArgAndResultAttrs(
178     Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
179     ArrayRef<NamedAttrList> resultAttrs) {
180   MLIRContext *context = builder.getContext();
181   auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
182     return llvm::to_vector<8>(
183         llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
184           return attrList.getDictionary(context);
185         }));
186   };
187   addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
188 }
189 
190 ParseResult mlir::function_interface_impl::parseFunctionOp(
191     OpAsmParser &parser, OperationState &result, bool allowVariadic,
192     FuncTypeBuilder funcTypeBuilder) {
193   SmallVector<OpAsmParser::OperandType, 4> entryArgs;
194   SmallVector<NamedAttrList, 4> argAttrs;
195   SmallVector<NamedAttrList, 4> resultAttrs;
196   SmallVector<Type, 4> argTypes;
197   SmallVector<Type, 4> resultTypes;
198   auto &builder = parser.getBuilder();
199 
200   // Parse visibility.
201   impl::parseOptionalVisibilityKeyword(parser, result.attributes);
202 
203   // Parse the name as a symbol.
204   StringAttr nameAttr;
205   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
206                              result.attributes))
207     return failure();
208 
209   // Parse the function signature.
210   llvm::SMLoc signatureLocation = parser.getCurrentLocation();
211   bool isVariadic = false;
212   if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
213                              argAttrs, isVariadic, resultTypes, resultAttrs))
214     return failure();
215 
216   std::string errorMessage;
217   Type type = funcTypeBuilder(builder, argTypes, resultTypes,
218                               VariadicFlag(isVariadic), errorMessage);
219   if (!type) {
220     return parser.emitError(signatureLocation)
221            << "failed to construct function type"
222            << (errorMessage.empty() ? "" : ": ") << errorMessage;
223   }
224   result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
225 
226   // If function attributes are present, parse them.
227   NamedAttrList parsedAttributes;
228   llvm::SMLoc attributeDictLocation = parser.getCurrentLocation();
229   if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
230     return failure();
231 
232   // Disallow attributes that are inferred from elsewhere in the attribute
233   // dictionary.
234   for (StringRef disallowed :
235        {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
236         getTypeAttrName()}) {
237     if (parsedAttributes.get(disallowed))
238       return parser.emitError(attributeDictLocation, "'")
239              << disallowed
240              << "' is an inferred attribute and should not be specified in the "
241                 "explicit attribute dictionary";
242   }
243   result.attributes.append(parsedAttributes);
244 
245   // Add the attributes to the function arguments.
246   assert(argAttrs.size() == argTypes.size());
247   assert(resultAttrs.size() == resultTypes.size());
248   addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
249 
250   // Parse the optional function body. The printer will not print the body if
251   // its empty, so disallow parsing of empty body in the parser.
252   auto *body = result.addRegion();
253   llvm::SMLoc loc = parser.getCurrentLocation();
254   OptionalParseResult parseResult = parser.parseOptionalRegion(
255       *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
256       /*enableNameShadowing=*/false);
257   if (parseResult.hasValue()) {
258     if (failed(*parseResult))
259       return failure();
260     // Function body was parsed, make sure its not empty.
261     if (body->empty())
262       return parser.emitError(loc, "expected non-empty function body");
263   }
264   return success();
265 }
266 
267 /// Print a function result list. The provided `attrs` must either be null, or
268 /// contain a set of DictionaryAttrs of the same arity as `types`.
269 static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
270                                     ArrayAttr attrs) {
271   assert(!types.empty() && "Should not be called for empty result list.");
272   assert((!attrs || attrs.size() == types.size()) &&
273          "Invalid number of attributes.");
274 
275   auto &os = p.getStream();
276   bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
277                      (attrs && !attrs[0].cast<DictionaryAttr>().empty());
278   if (needsParens)
279     os << '(';
280   llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
281     p.printType(types[i]);
282     if (attrs)
283       p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
284   });
285   if (needsParens)
286     os << ')';
287 }
288 
289 void mlir::function_interface_impl::printFunctionSignature(
290     OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
291     ArrayRef<Type> resultTypes) {
292   Region &body = op->getRegion(0);
293   bool isExternal = body.empty();
294 
295   p << '(';
296   ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
297   for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
298     if (i > 0)
299       p << ", ";
300 
301     if (!isExternal) {
302       ArrayRef<NamedAttribute> attrs;
303       if (argAttrs)
304         attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
305       p.printRegionArgument(body.getArgument(i), attrs);
306     } else {
307       p.printType(argTypes[i]);
308       if (argAttrs)
309         p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
310     }
311   }
312 
313   if (isVariadic) {
314     if (!argTypes.empty())
315       p << ", ";
316     p << "...";
317   }
318 
319   p << ')';
320 
321   if (!resultTypes.empty()) {
322     p.getStream() << " -> ";
323     auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
324     printFunctionResultList(p, resultTypes, resultAttrs);
325   }
326 }
327 
328 void mlir::function_interface_impl::printFunctionAttributes(
329     OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
330     ArrayRef<StringRef> elided) {
331   // Print out function attributes, if present.
332   SmallVector<StringRef, 2> ignoredAttrs = {
333       ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(),
334       getArgDictAttrName(), getResultDictAttrName()};
335   ignoredAttrs.append(elided.begin(), elided.end());
336 
337   p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
338 }
339 
340 void mlir::function_interface_impl::printFunctionOp(
341     OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
342     ArrayRef<Type> resultTypes) {
343   // Print the operation and the function name.
344   auto funcName =
345       op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
346           .getValue();
347   p << ' ';
348 
349   StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
350   if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
351     p << visibility.getValue() << ' ';
352   p.printSymbolName(funcName);
353 
354   printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
355   printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
356                           {visibilityAttrName});
357   // Print the body if this is not an external function.
358   Region &body = op->getRegion(0);
359   if (!body.empty()) {
360     p << ' ';
361     p.printRegion(body, /*printEntryBlockArgs=*/false,
362                   /*printBlockTerminators=*/true);
363   }
364 }
365