//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/IR/Builders.h" #include "mlir/IR/FunctionImplementation.h" using namespace mlir; using namespace mlir::ml_program; //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); } //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, buildFuncType); } void SubgraphOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); } //===----------------------------------------------------------------------===// // OutputOp //===----------------------------------------------------------------------===// LogicalResult OutputOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") outputs " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of output operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); }