1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 12 #include "mlir/IR/BlockAndValueMapping.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/PatternMatch.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/IR/Value.h" 21 #include "mlir/Support/MathExtras.h" 22 #include "mlir/Transforms/InliningUtils.h" 23 #include "llvm/ADT/APFloat.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/StringSwitch.h" 26 #include "llvm/Support/FormatVariadic.h" 27 #include "llvm/Support/raw_ostream.h" 28 #include <numeric> 29 30 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::func; 34 35 //===----------------------------------------------------------------------===// 36 // FuncDialect Interfaces 37 //===----------------------------------------------------------------------===// 38 namespace { 39 /// This class defines the interface for handling inlining with func operations. 40 struct FuncInlinerInterface : public DialectInlinerInterface { 41 using DialectInlinerInterface::DialectInlinerInterface; 42 43 //===--------------------------------------------------------------------===// 44 // Analysis Hooks 45 //===--------------------------------------------------------------------===// 46 47 /// All call operations can be inlined. 48 bool isLegalToInline(Operation *call, Operation *callable, 49 bool wouldBeCloned) const final { 50 return true; 51 } 52 53 /// All operations can be inlined. 54 bool isLegalToInline(Operation *, Region *, bool, 55 BlockAndValueMapping &) const final { 56 return true; 57 } 58 59 //===--------------------------------------------------------------------===// 60 // Transformation Hooks 61 //===--------------------------------------------------------------------===// 62 63 /// Handle the given inlined terminator by replacing it with a new operation 64 /// as necessary. 65 void handleTerminator(Operation *op, Block *newDest) const final { 66 // Only return needs to be handled here. 67 auto returnOp = dyn_cast<ReturnOp>(op); 68 if (!returnOp) 69 return; 70 71 // Replace the return with a branch to the dest. 72 OpBuilder builder(op); 73 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); 74 op->erase(); 75 } 76 77 /// Handle the given inlined terminator by replacing it with a new operation 78 /// as necessary. 79 void handleTerminator(Operation *op, 80 ArrayRef<Value> valuesToRepl) const final { 81 // Only return needs to be handled here. 82 auto returnOp = cast<ReturnOp>(op); 83 84 // Replace the values directly with the return operands. 85 assert(returnOp.getNumOperands() == valuesToRepl.size()); 86 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 87 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 88 } 89 }; 90 } // namespace 91 92 //===----------------------------------------------------------------------===// 93 // FuncDialect 94 //===----------------------------------------------------------------------===// 95 96 void FuncDialect::initialize() { 97 addOperations< 98 #define GET_OP_LIST 99 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 100 >(); 101 addInterfaces<FuncInlinerInterface>(); 102 } 103 104 /// Materialize a single constant operation from a given attribute value with 105 /// the desired resultant type. 106 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, 107 Type type, Location loc) { 108 if (ConstantOp::isBuildableWith(value, type)) 109 return builder.create<ConstantOp>(loc, type, 110 value.cast<FlatSymbolRefAttr>()); 111 return nullptr; 112 } 113 114 //===----------------------------------------------------------------------===// 115 // CallOp 116 //===----------------------------------------------------------------------===// 117 118 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 119 // Check that the callee attribute was specified. 120 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 121 if (!fnAttr) 122 return emitOpError("requires a 'callee' symbol reference attribute"); 123 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); 124 if (!fn) 125 return emitOpError() << "'" << fnAttr.getValue() 126 << "' does not reference a valid function"; 127 128 // Verify that the operand and result types match the callee. 129 auto fnType = fn.getType(); 130 if (fnType.getNumInputs() != getNumOperands()) 131 return emitOpError("incorrect number of operands for callee"); 132 133 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 134 if (getOperand(i).getType() != fnType.getInput(i)) 135 return emitOpError("operand type mismatch: expected operand type ") 136 << fnType.getInput(i) << ", but provided " 137 << getOperand(i).getType() << " for operand number " << i; 138 139 if (fnType.getNumResults() != getNumResults()) 140 return emitOpError("incorrect number of results for callee"); 141 142 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 143 if (getResult(i).getType() != fnType.getResult(i)) { 144 auto diag = emitOpError("result type mismatch at index ") << i; 145 diag.attachNote() << " op result types: " << getResultTypes(); 146 diag.attachNote() << "function result types: " << fnType.getResults(); 147 return diag; 148 } 149 150 return success(); 151 } 152 153 FunctionType CallOp::getCalleeType() { 154 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // CallIndirectOp 159 //===----------------------------------------------------------------------===// 160 161 /// Fold indirect calls that have a constant function as the callee operand. 162 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, 163 PatternRewriter &rewriter) { 164 // Check that the callee is a constant callee. 165 SymbolRefAttr calledFn; 166 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 167 return failure(); 168 169 // Replace with a direct call. 170 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, 171 indirectCall.getResultTypes(), 172 indirectCall.getArgOperands()); 173 return success(); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // ConstantOp 178 //===----------------------------------------------------------------------===// 179 180 LogicalResult ConstantOp::verify() { 181 StringRef fnName = getValue(); 182 Type type = getType(); 183 184 // Try to find the referenced function. 185 auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName); 186 if (!fn) 187 return emitOpError() << "reference to undefined function '" << fnName 188 << "'"; 189 190 // Check that the referenced function has the correct type. 191 if (fn.getType() != type) 192 return emitOpError("reference to function with mismatched type"); 193 194 return success(); 195 } 196 197 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { 198 assert(operands.empty() && "constant has no operands"); 199 return getValueAttr(); 200 } 201 202 void ConstantOp::getAsmResultNames( 203 function_ref<void(Value, StringRef)> setNameFn) { 204 setNameFn(getResult(), "f"); 205 } 206 207 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 208 return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>(); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // ReturnOp 213 //===----------------------------------------------------------------------===// 214 215 LogicalResult ReturnOp::verify() { 216 auto function = cast<FuncOp>((*this)->getParentOp()); 217 218 // The operand number and types must match the function signature. 219 const auto &results = function.getType().getResults(); 220 if (getNumOperands() != results.size()) 221 return emitOpError("has ") 222 << getNumOperands() << " operands, but enclosing function (@" 223 << function.getName() << ") returns " << results.size(); 224 225 for (unsigned i = 0, e = results.size(); i != e; ++i) 226 if (getOperand(i).getType() != results[i]) 227 return emitError() << "type of return operand " << i << " (" 228 << getOperand(i).getType() 229 << ") doesn't match function result type (" 230 << results[i] << ")" 231 << " in function @" << function.getName(); 232 233 return success(); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // TableGen'd op method definitions 238 //===----------------------------------------------------------------------===// 239 240 #define GET_OP_CLASSES 241 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 242