1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // TableGen'd op method definitions 18 //===----------------------------------------------------------------------===// 19 20 #define GET_OP_CLASSES 21 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 22 23 //===----------------------------------------------------------------------===// 24 // FuncOp 25 //===----------------------------------------------------------------------===// 26 27 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 28 auto buildFuncType = 29 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 30 function_interface_impl::VariadicFlag, 31 std::string &) { return builder.getFunctionType(argTypes, results); }; 32 33 return function_interface_impl::parseFunctionOp( 34 parser, result, /*allowVariadic=*/false, buildFuncType); 35 } 36 37 void FuncOp::print(OpAsmPrinter &p) { 38 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 39 } 40 41 //===----------------------------------------------------------------------===// 42 // SubgraphOp 43 //===----------------------------------------------------------------------===// 44 45 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 46 auto buildFuncType = 47 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 48 function_interface_impl::VariadicFlag, 49 std::string &) { return builder.getFunctionType(argTypes, results); }; 50 51 return function_interface_impl::parseFunctionOp( 52 parser, result, /*allowVariadic=*/false, buildFuncType); 53 } 54 55 void SubgraphOp::print(OpAsmPrinter &p) { 56 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // OutputOp 61 //===----------------------------------------------------------------------===// 62 63 LogicalResult OutputOp::verify() { 64 auto function = cast<SubgraphOp>((*this)->getParentOp()); 65 66 // The operand number and types must match the function signature. 67 const auto &results = function.getFunctionType().getResults(); 68 if (getNumOperands() != results.size()) 69 return emitOpError("has ") 70 << getNumOperands() << " operands, but enclosing function (@" 71 << function.getName() << ") outputs " << results.size(); 72 73 for (unsigned i = 0, e = results.size(); i != e; ++i) 74 if (getOperand(i).getType() != results[i]) 75 return emitError() << "type of output operand " << i << " (" 76 << getOperand(i).getType() 77 << ") doesn't match function result type (" 78 << results[i] << ")" 79 << " in function @" << function.getName(); 80 81 return success(); 82 } 83 84 //===----------------------------------------------------------------------===// 85 // ReturnOp 86 //===----------------------------------------------------------------------===// 87 88 LogicalResult ReturnOp::verify() { 89 auto function = cast<FuncOp>((*this)->getParentOp()); 90 91 // The operand number and types must match the function signature. 92 const auto &results = function.getFunctionType().getResults(); 93 if (getNumOperands() != results.size()) 94 return emitOpError("has ") 95 << getNumOperands() << " operands, but enclosing function (@" 96 << function.getName() << ") returns " << results.size(); 97 98 for (unsigned i = 0, e = results.size(); i != e; ++i) 99 if (getOperand(i).getType() != results[i]) 100 return emitError() << "type of return operand " << i << " (" 101 << getOperand(i).getType() 102 << ") doesn't match function result type (" 103 << results[i] << ")" 104 << " in function @" << function.getName(); 105 106 return success(); 107 } 108