1*61352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2*61352a58SStella Laurenzo // 3*61352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*61352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5*61352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*61352a58SStella Laurenzo // 7*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 8*61352a58SStella Laurenzo 9*61352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10*61352a58SStella Laurenzo #include "mlir/IR/Builders.h" 11*61352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h" 12*61352a58SStella Laurenzo 13*61352a58SStella Laurenzo using namespace mlir; 14*61352a58SStella Laurenzo using namespace mlir::ml_program; 15*61352a58SStella Laurenzo 16*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 17*61352a58SStella Laurenzo // TableGen'd op method definitions 18*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 19*61352a58SStella Laurenzo 20*61352a58SStella Laurenzo #define GET_OP_CLASSES 21*61352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 22*61352a58SStella Laurenzo 23*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 24*61352a58SStella Laurenzo // FuncOp 25*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 26*61352a58SStella Laurenzo 27*61352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 28*61352a58SStella Laurenzo auto buildFuncType = 29*61352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 30*61352a58SStella Laurenzo function_interface_impl::VariadicFlag, 31*61352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 32*61352a58SStella Laurenzo 33*61352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 34*61352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 35*61352a58SStella Laurenzo } 36*61352a58SStella Laurenzo 37*61352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) { 38*61352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 39*61352a58SStella Laurenzo } 40*61352a58SStella Laurenzo 41*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 42*61352a58SStella Laurenzo // SubgraphOp 43*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 44*61352a58SStella Laurenzo 45*61352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 46*61352a58SStella Laurenzo auto buildFuncType = 47*61352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 48*61352a58SStella Laurenzo function_interface_impl::VariadicFlag, 49*61352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 50*61352a58SStella Laurenzo 51*61352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 52*61352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 53*61352a58SStella Laurenzo } 54*61352a58SStella Laurenzo 55*61352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) { 56*61352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 57*61352a58SStella Laurenzo } 58*61352a58SStella Laurenzo 59*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 60*61352a58SStella Laurenzo // OutputOp 61*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 62*61352a58SStella Laurenzo 63*61352a58SStella Laurenzo LogicalResult OutputOp::verify() { 64*61352a58SStella Laurenzo auto function = cast<SubgraphOp>((*this)->getParentOp()); 65*61352a58SStella Laurenzo 66*61352a58SStella Laurenzo // The operand number and types must match the function signature. 67*61352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 68*61352a58SStella Laurenzo if (getNumOperands() != results.size()) 69*61352a58SStella Laurenzo return emitOpError("has ") 70*61352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 71*61352a58SStella Laurenzo << function.getName() << ") outputs " << results.size(); 72*61352a58SStella Laurenzo 73*61352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 74*61352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 75*61352a58SStella Laurenzo return emitError() << "type of output operand " << i << " (" 76*61352a58SStella Laurenzo << getOperand(i).getType() 77*61352a58SStella Laurenzo << ") doesn't match function result type (" 78*61352a58SStella Laurenzo << results[i] << ")" 79*61352a58SStella Laurenzo << " in function @" << function.getName(); 80*61352a58SStella Laurenzo 81*61352a58SStella Laurenzo return success(); 82*61352a58SStella Laurenzo } 83*61352a58SStella Laurenzo 84*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 85*61352a58SStella Laurenzo // ReturnOp 86*61352a58SStella Laurenzo //===----------------------------------------------------------------------===// 87*61352a58SStella Laurenzo 88*61352a58SStella Laurenzo LogicalResult ReturnOp::verify() { 89*61352a58SStella Laurenzo auto function = cast<FuncOp>((*this)->getParentOp()); 90*61352a58SStella Laurenzo 91*61352a58SStella Laurenzo // The operand number and types must match the function signature. 92*61352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 93*61352a58SStella Laurenzo if (getNumOperands() != results.size()) 94*61352a58SStella Laurenzo return emitOpError("has ") 95*61352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 96*61352a58SStella Laurenzo << function.getName() << ") returns " << results.size(); 97*61352a58SStella Laurenzo 98*61352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 99*61352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 100*61352a58SStella Laurenzo return emitError() << "type of return operand " << i << " (" 101*61352a58SStella Laurenzo << getOperand(i).getType() 102*61352a58SStella Laurenzo << ") doesn't match function result type (" 103*61352a58SStella Laurenzo << results[i] << ")" 104*61352a58SStella Laurenzo << " in function @" << function.getName(); 105*61352a58SStella Laurenzo 106*61352a58SStella Laurenzo return success(); 107*61352a58SStella Laurenzo } 108