1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 
11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/FunctionImplementation.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
22 #include "mlir/Support/MathExtras.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/StringSwitch.h"
28 #include "llvm/Support/FormatVariadic.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include <numeric>
31 
32 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
33 
34 using namespace mlir;
35 using namespace mlir::func;
36 
37 //===----------------------------------------------------------------------===//
38 // FuncDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 namespace {
41 /// This class defines the interface for handling inlining with func operations.
42 struct FuncInlinerInterface : public DialectInlinerInterface {
43   using DialectInlinerInterface::DialectInlinerInterface;
44 
45   //===--------------------------------------------------------------------===//
46   // Analysis Hooks
47   //===--------------------------------------------------------------------===//
48 
49   /// All call operations can be inlined.
isLegalToInline__anon8a470dc10111::FuncInlinerInterface50   bool isLegalToInline(Operation *call, Operation *callable,
51                        bool wouldBeCloned) const final {
52     return true;
53   }
54 
55   /// All operations can be inlined.
isLegalToInline__anon8a470dc10111::FuncInlinerInterface56   bool isLegalToInline(Operation *, Region *, bool,
57                        BlockAndValueMapping &) const final {
58     return true;
59   }
60 
61   /// All functions can be inlined.
isLegalToInline__anon8a470dc10111::FuncInlinerInterface62   bool isLegalToInline(Region *, Region *, bool,
63                        BlockAndValueMapping &) const final {
64     return true;
65   }
66 
67   //===--------------------------------------------------------------------===//
68   // Transformation Hooks
69   //===--------------------------------------------------------------------===//
70 
71   /// Handle the given inlined terminator by replacing it with a new operation
72   /// as necessary.
handleTerminator__anon8a470dc10111::FuncInlinerInterface73   void handleTerminator(Operation *op, Block *newDest) const final {
74     // Only return needs to be handled here.
75     auto returnOp = dyn_cast<ReturnOp>(op);
76     if (!returnOp)
77       return;
78 
79     // Replace the return with a branch to the dest.
80     OpBuilder builder(op);
81     builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
82     op->erase();
83   }
84 
85   /// Handle the given inlined terminator by replacing it with a new operation
86   /// as necessary.
handleTerminator__anon8a470dc10111::FuncInlinerInterface87   void handleTerminator(Operation *op,
88                         ArrayRef<Value> valuesToRepl) const final {
89     // Only return needs to be handled here.
90     auto returnOp = cast<ReturnOp>(op);
91 
92     // Replace the values directly with the return operands.
93     assert(returnOp.getNumOperands() == valuesToRepl.size());
94     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
95       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
96   }
97 };
98 } // namespace
99 
100 //===----------------------------------------------------------------------===//
101 // FuncDialect
102 //===----------------------------------------------------------------------===//
103 
/// Register the func dialect's contents: the TableGen-generated operation list
/// (expanded from FuncOps.cpp.inc via GET_OP_LIST) and the inliner interface
/// defined above in this file.
void FuncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  addInterfaces<FuncInlinerInterface>();
}
111 
112 /// Materialize a single constant operation from a given attribute value with
113 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)114 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
115                                             Type type, Location loc) {
116   if (ConstantOp::isBuildableWith(value, type))
117     return builder.create<ConstantOp>(loc, type,
118                                       value.cast<FlatSymbolRefAttr>());
119   return nullptr;
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // CallOp
124 //===----------------------------------------------------------------------===//
125 
verifySymbolUses(SymbolTableCollection & symbolTable)126 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
127   // Check that the callee attribute was specified.
128   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
129   if (!fnAttr)
130     return emitOpError("requires a 'callee' symbol reference attribute");
131   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
132   if (!fn)
133     return emitOpError() << "'" << fnAttr.getValue()
134                          << "' does not reference a valid function";
135 
136   // Verify that the operand and result types match the callee.
137   auto fnType = fn.getFunctionType();
138   if (fnType.getNumInputs() != getNumOperands())
139     return emitOpError("incorrect number of operands for callee");
140 
141   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
142     if (getOperand(i).getType() != fnType.getInput(i))
143       return emitOpError("operand type mismatch: expected operand type ")
144              << fnType.getInput(i) << ", but provided "
145              << getOperand(i).getType() << " for operand number " << i;
146 
147   if (fnType.getNumResults() != getNumResults())
148     return emitOpError("incorrect number of results for callee");
149 
150   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
151     if (getResult(i).getType() != fnType.getResult(i)) {
152       auto diag = emitOpError("result type mismatch at index ") << i;
153       diag.attachNote() << "      op result types: " << getResultTypes();
154       diag.attachNote() << "function result types: " << fnType.getResults();
155       return diag;
156     }
157 
158   return success();
159 }
160 
getCalleeType()161 FunctionType CallOp::getCalleeType() {
162   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // CallIndirectOp
167 //===----------------------------------------------------------------------===//
168 
169 /// Fold indirect calls that have a constant function as the callee operand.
canonicalize(CallIndirectOp indirectCall,PatternRewriter & rewriter)170 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
171                                            PatternRewriter &rewriter) {
172   // Check that the callee is a constant callee.
173   SymbolRefAttr calledFn;
174   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
175     return failure();
176 
177   // Replace with a direct call.
178   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
179                                       indirectCall.getResultTypes(),
180                                       indirectCall.getArgOperands());
181   return success();
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // ConstantOp
186 //===----------------------------------------------------------------------===//
187 
verify()188 LogicalResult ConstantOp::verify() {
189   StringRef fnName = getValue();
190   Type type = getType();
191 
192   // Try to find the referenced function.
193   auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
194   if (!fn)
195     return emitOpError() << "reference to undefined function '" << fnName
196                          << "'";
197 
198   // Check that the referenced function has the correct type.
199   if (fn.getFunctionType() != type)
200     return emitOpError("reference to function with mismatched type");
201 
202   return success();
203 }
204 
fold(ArrayRef<Attribute> operands)205 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
206   assert(operands.empty() && "constant has no operands");
207   return getValueAttr();
208 }
209 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)210 void ConstantOp::getAsmResultNames(
211     function_ref<void(Value, StringRef)> setNameFn) {
212   setNameFn(getResult(), "f");
213 }
214 
isBuildableWith(Attribute value,Type type)215 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
216   return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // FuncOp
221 //===----------------------------------------------------------------------===//
222 
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs)223 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
224                       ArrayRef<NamedAttribute> attrs) {
225   OpBuilder builder(location->getContext());
226   OperationState state(location, getOperationName());
227   FuncOp::build(builder, state, name, type, attrs);
228   return cast<FuncOp>(Operation::create(state));
229 }
create(Location location,StringRef name,FunctionType type,Operation::dialect_attr_range attrs)230 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
231                       Operation::dialect_attr_range attrs) {
232   SmallVector<NamedAttribute, 8> attrRef(attrs);
233   return create(location, name, type, llvm::makeArrayRef(attrRef));
234 }
create(Location location,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)235 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
236                       ArrayRef<NamedAttribute> attrs,
237                       ArrayRef<DictionaryAttr> argAttrs) {
238   FuncOp func = create(location, name, type, attrs);
239   func.setAllArgAttrs(argAttrs);
240   return func;
241 }
242 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,ArrayRef<NamedAttribute> attrs,ArrayRef<DictionaryAttr> argAttrs)243 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
244                    FunctionType type, ArrayRef<NamedAttribute> attrs,
245                    ArrayRef<DictionaryAttr> argAttrs) {
246   state.addAttribute(SymbolTable::getSymbolAttrName(),
247                      builder.getStringAttr(name));
248   state.addAttribute(FunctionOpInterface::getTypeAttrName(),
249                      TypeAttr::get(type));
250   state.attributes.append(attrs.begin(), attrs.end());
251   state.addRegion();
252 
253   if (argAttrs.empty())
254     return;
255   assert(type.getNumInputs() == argAttrs.size());
256   function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
257                                                 /*resultAttrs=*/llvm::None);
258 }
259 
parse(OpAsmParser & parser,OperationState & result)260 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
261   auto buildFuncType =
262       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
263          function_interface_impl::VariadicFlag,
264          std::string &) { return builder.getFunctionType(argTypes, results); };
265 
266   return function_interface_impl::parseFunctionOp(
267       parser, result, /*allowVariadic=*/false, buildFuncType);
268 }
269 
print(OpAsmPrinter & p)270 void FuncOp::print(OpAsmPrinter &p) {
271   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
272 }
273 
274 /// Clone the internal blocks from this function into dest and all attributes
275 /// from this function to dest.
cloneInto(FuncOp dest,BlockAndValueMapping & mapper)276 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
277   // Add the attributes of this function to dest.
278   llvm::MapVector<StringAttr, Attribute> newAttrMap;
279   for (const auto &attr : dest->getAttrs())
280     newAttrMap.insert({attr.getName(), attr.getValue()});
281   for (const auto &attr : (*this)->getAttrs())
282     newAttrMap.insert({attr.getName(), attr.getValue()});
283 
284   auto newAttrs = llvm::to_vector(llvm::map_range(
285       newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
286         return NamedAttribute(attrPair.first, attrPair.second);
287       }));
288   dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
289 
290   // Clone the body.
291   getBody().cloneInto(&dest.getBody(), mapper);
292 }
293 
294 /// Create a deep copy of this function and all of its blocks, remapping
295 /// any operands that use values outside of the function using the map that is
296 /// provided (leaving them alone if no entry is present). Replaces references
297 /// to cloned sub-values with the corresponding value that is copied, and adds
298 /// those mappings to the mapper.
clone(BlockAndValueMapping & mapper)299 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
300   // Create the new function.
301   FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
302 
303   // If the function has a body, then the user might be deleting arguments to
304   // the function by specifying them in the mapper. If so, we don't add the
305   // argument to the input type vector.
306   if (!isExternal()) {
307     FunctionType oldType = getFunctionType();
308 
309     unsigned oldNumArgs = oldType.getNumInputs();
310     SmallVector<Type, 4> newInputs;
311     newInputs.reserve(oldNumArgs);
312     for (unsigned i = 0; i != oldNumArgs; ++i)
313       if (!mapper.contains(getArgument(i)))
314         newInputs.push_back(oldType.getInput(i));
315 
316     /// If any of the arguments were dropped, update the type and drop any
317     /// necessary argument attributes.
318     if (newInputs.size() != oldNumArgs) {
319       newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
320                                         oldType.getResults()));
321 
322       if (ArrayAttr argAttrs = getAllArgAttrs()) {
323         SmallVector<Attribute> newArgAttrs;
324         newArgAttrs.reserve(newInputs.size());
325         for (unsigned i = 0; i != oldNumArgs; ++i)
326           if (!mapper.contains(getArgument(i)))
327             newArgAttrs.push_back(argAttrs[i]);
328         newFunc.setAllArgAttrs(newArgAttrs);
329       }
330     }
331   }
332 
333   /// Clone the current function into the new one and return it.
334   cloneInto(newFunc, mapper);
335   return newFunc;
336 }
clone()337 FuncOp FuncOp::clone() {
338   BlockAndValueMapping mapper;
339   return clone(mapper);
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // ReturnOp
344 //===----------------------------------------------------------------------===//
345 
verify()346 LogicalResult ReturnOp::verify() {
347   auto function = cast<FuncOp>((*this)->getParentOp());
348 
349   // The operand number and types must match the function signature.
350   const auto &results = function.getFunctionType().getResults();
351   if (getNumOperands() != results.size())
352     return emitOpError("has ")
353            << getNumOperands() << " operands, but enclosing function (@"
354            << function.getName() << ") returns " << results.size();
355 
356   for (unsigned i = 0, e = results.size(); i != e; ++i)
357     if (getOperand(i).getType() != results[i])
358       return emitError() << "type of return operand " << i << " ("
359                          << getOperand(i).getType()
360                          << ") doesn't match function result type ("
361                          << results[i] << ")"
362                          << " in function @" << function.getName();
363 
364   return success();
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // TableGen'd op method definitions
369 //===----------------------------------------------------------------------===//
370 
371 #define GET_OP_CLASSES
372 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
373