//===- FuncOps.cpp - Func Dialect Operations ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" using namespace mlir; using namespace mlir::func; //===----------------------------------------------------------------------===// // FuncDialect Interfaces //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for handling inlining with func operations. struct FuncInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// All call operations can be inlined. bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { return true; } /// All operations can be inlined. bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { // Only return needs to be handled here. auto returnOp = dyn_cast(op); if (!returnOp) return; // Replace the return with a branch to the dest. OpBuilder builder(op); builder.create(op->getLoc(), newDest, returnOp.getOperands()); op->erase(); } /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef valuesToRepl) const final { // Only return needs to be handled here. auto returnOp = cast(op); // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } }; } // namespace //===----------------------------------------------------------------------===// // FuncDialect //===----------------------------------------------------------------------===// void FuncDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" >(); addInterfaces(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value.cast()); return nullptr; } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the callee attribute was specified. auto fnAttr = (*this)->getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getType(); if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (getOperand(i).getType() != fnType.getInput(i)) return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (getResult(i).getType() != fnType.getResult(i)) { auto diag = emitOpError("result type mismatch at index ") << i; diag.attachNote() << " op result types: " << getResultTypes(); diag.attachNote() << "function result types: " << fnType.getResults(); return diag; } return success(); } FunctionType CallOp::getCalleeType() { return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// /// Fold indirect calls that have a constant function as the callee operand. LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, PatternRewriter &rewriter) { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return failure(); // Replace with a direct call. rewriter.replaceOpWithNewOp(indirectCall, calledFn, indirectCall.getResultTypes(), indirectCall.getArgOperands()); return success(); } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// LogicalResult ConstantOp::verify() { StringRef fnName = getValue(); Type type = getType(); // Try to find the referenced function. auto fn = (*this)->getParentOfType().lookupSymbol(fnName); if (!fn) return emitOpError() << "reference to undefined function '" << fnName << "'"; // Check that the referenced function has the correct type. if (fn.getType() != type) return emitOpError("reference to function with mismatched type"); return success(); } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); return getValueAttr(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "f"); } bool ConstantOp::isBuildableWith(Attribute value, Type type) { return value.isa() && type.isa(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"