1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 
11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/IR/Value.h"
21 #include "mlir/Support/MathExtras.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include <numeric>
29 
30 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::func;
34 
35 //===----------------------------------------------------------------------===//
36 // FuncDialect Interfaces
37 //===----------------------------------------------------------------------===//
38 namespace {
39 /// This class defines the interface for handling inlining with func operations.
40 struct FuncInlinerInterface : public DialectInlinerInterface {
41   using DialectInlinerInterface::DialectInlinerInterface;
42 
43   //===--------------------------------------------------------------------===//
44   // Analysis Hooks
45   //===--------------------------------------------------------------------===//
46 
47   /// All call operations can be inlined.
48   bool isLegalToInline(Operation *call, Operation *callable,
49                        bool wouldBeCloned) const final {
50     return true;
51   }
52 
53   /// All operations can be inlined.
54   bool isLegalToInline(Operation *, Region *, bool,
55                        BlockAndValueMapping &) const final {
56     return true;
57   }
58 
59   //===--------------------------------------------------------------------===//
60   // Transformation Hooks
61   //===--------------------------------------------------------------------===//
62 
63   /// Handle the given inlined terminator by replacing it with a new operation
64   /// as necessary.
65   void handleTerminator(Operation *op, Block *newDest) const final {
66     // Only return needs to be handled here.
67     auto returnOp = dyn_cast<ReturnOp>(op);
68     if (!returnOp)
69       return;
70 
71     // Replace the return with a branch to the dest.
72     OpBuilder builder(op);
73     builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
74     op->erase();
75   }
76 
77   /// Handle the given inlined terminator by replacing it with a new operation
78   /// as necessary.
79   void handleTerminator(Operation *op,
80                         ArrayRef<Value> valuesToRepl) const final {
81     // Only return needs to be handled here.
82     auto returnOp = cast<ReturnOp>(op);
83 
84     // Replace the values directly with the return operands.
85     assert(returnOp.getNumOperands() == valuesToRepl.size());
86     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
87       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
88   }
89 };
90 } // namespace
91 
92 //===----------------------------------------------------------------------===//
93 // FuncDialect
94 //===----------------------------------------------------------------------===//
95 
/// Dialect setup: registers every operation listed by the TableGen-generated
/// GET_OP_LIST section of FuncOps.cpp.inc and attaches FuncInlinerInterface.
void FuncDialect::initialize() {
  // The generated include expands to a comma-separated list of op classes,
  // which becomes the template argument list of addOperations<>.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  addInterfaces<FuncInlinerInterface>();
}
103 
104 /// Materialize a single constant operation from a given attribute value with
105 /// the desired resultant type.
106 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
107                                             Type type, Location loc) {
108   if (ConstantOp::isBuildableWith(value, type))
109     return builder.create<ConstantOp>(loc, type,
110                                       value.cast<FlatSymbolRefAttr>());
111   return nullptr;
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // CallOp
116 //===----------------------------------------------------------------------===//
117 
118 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
119   // Check that the callee attribute was specified.
120   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
121   if (!fnAttr)
122     return emitOpError("requires a 'callee' symbol reference attribute");
123   FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
124   if (!fn)
125     return emitOpError() << "'" << fnAttr.getValue()
126                          << "' does not reference a valid function";
127 
128   // Verify that the operand and result types match the callee.
129   auto fnType = fn.getType();
130   if (fnType.getNumInputs() != getNumOperands())
131     return emitOpError("incorrect number of operands for callee");
132 
133   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
134     if (getOperand(i).getType() != fnType.getInput(i))
135       return emitOpError("operand type mismatch: expected operand type ")
136              << fnType.getInput(i) << ", but provided "
137              << getOperand(i).getType() << " for operand number " << i;
138 
139   if (fnType.getNumResults() != getNumResults())
140     return emitOpError("incorrect number of results for callee");
141 
142   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
143     if (getResult(i).getType() != fnType.getResult(i)) {
144       auto diag = emitOpError("result type mismatch at index ") << i;
145       diag.attachNote() << "      op result types: " << getResultTypes();
146       diag.attachNote() << "function result types: " << fnType.getResults();
147       return diag;
148     }
149 
150   return success();
151 }
152 
153 FunctionType CallOp::getCalleeType() {
154   return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // CallIndirectOp
159 //===----------------------------------------------------------------------===//
160 
161 /// Fold indirect calls that have a constant function as the callee operand.
162 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
163                                            PatternRewriter &rewriter) {
164   // Check that the callee is a constant callee.
165   SymbolRefAttr calledFn;
166   if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
167     return failure();
168 
169   // Replace with a direct call.
170   rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
171                                       indirectCall.getResultTypes(),
172                                       indirectCall.getArgOperands());
173   return success();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // ConstantOp
178 //===----------------------------------------------------------------------===//
179 
180 LogicalResult ConstantOp::verify() {
181   StringRef fnName = getValue();
182   Type type = getType();
183 
184   // Try to find the referenced function.
185   auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
186   if (!fn)
187     return emitOpError() << "reference to undefined function '" << fnName
188                          << "'";
189 
190   // Check that the referenced function has the correct type.
191   if (fn.getType() != type)
192     return emitOpError("reference to function with mismatched type");
193 
194   return success();
195 }
196 
197 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
198   assert(operands.empty() && "constant has no operands");
199   return getValueAttr();
200 }
201 
202 void ConstantOp::getAsmResultNames(
203     function_ref<void(Value, StringRef)> setNameFn) {
204   setNameFn(getResult(), "f");
205 }
206 
207 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
208   return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>();
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // ReturnOp
213 //===----------------------------------------------------------------------===//
214 
215 LogicalResult ReturnOp::verify() {
216   auto function = cast<FuncOp>((*this)->getParentOp());
217 
218   // The operand number and types must match the function signature.
219   const auto &results = function.getType().getResults();
220   if (getNumOperands() != results.size())
221     return emitOpError("has ")
222            << getNumOperands() << " operands, but enclosing function (@"
223            << function.getName() << ") returns " << results.size();
224 
225   for (unsigned i = 0, e = results.size(); i != e; ++i)
226     if (getOperand(i).getType() != results[i])
227       return emitError() << "type of return operand " << i << " ("
228                          << getOperand(i).getType()
229                          << ") doesn't match function result type ("
230                          << results[i] << ")"
231                          << " in function @" << function.getName();
232 
233   return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // TableGen'd op method definitions
238 //===----------------------------------------------------------------------===//
239 
240 #define GET_OP_CLASSES
241 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
242