1*61352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2*61352a58SStella Laurenzo //
3*61352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*61352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5*61352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*61352a58SStella Laurenzo //
7*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
8*61352a58SStella Laurenzo 
9*61352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10*61352a58SStella Laurenzo #include "mlir/IR/Builders.h"
11*61352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h"
12*61352a58SStella Laurenzo 
13*61352a58SStella Laurenzo using namespace mlir;
14*61352a58SStella Laurenzo using namespace mlir::ml_program;
15*61352a58SStella Laurenzo 
16*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
17*61352a58SStella Laurenzo // TableGen'd op method definitions
18*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
19*61352a58SStella Laurenzo 
20*61352a58SStella Laurenzo #define GET_OP_CLASSES
21*61352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
22*61352a58SStella Laurenzo 
23*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
24*61352a58SStella Laurenzo // FuncOp
25*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
26*61352a58SStella Laurenzo 
27*61352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
28*61352a58SStella Laurenzo   auto buildFuncType =
29*61352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
30*61352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
31*61352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
32*61352a58SStella Laurenzo 
33*61352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
34*61352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
35*61352a58SStella Laurenzo }
36*61352a58SStella Laurenzo 
37*61352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) {
38*61352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
39*61352a58SStella Laurenzo }
40*61352a58SStella Laurenzo 
41*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
42*61352a58SStella Laurenzo // SubgraphOp
43*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
44*61352a58SStella Laurenzo 
45*61352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
46*61352a58SStella Laurenzo   auto buildFuncType =
47*61352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
48*61352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
49*61352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
50*61352a58SStella Laurenzo 
51*61352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
52*61352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
53*61352a58SStella Laurenzo }
54*61352a58SStella Laurenzo 
55*61352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) {
56*61352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
57*61352a58SStella Laurenzo }
58*61352a58SStella Laurenzo 
59*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
60*61352a58SStella Laurenzo // OutputOp
61*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
62*61352a58SStella Laurenzo 
63*61352a58SStella Laurenzo LogicalResult OutputOp::verify() {
64*61352a58SStella Laurenzo   auto function = cast<SubgraphOp>((*this)->getParentOp());
65*61352a58SStella Laurenzo 
66*61352a58SStella Laurenzo   // The operand number and types must match the function signature.
67*61352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
68*61352a58SStella Laurenzo   if (getNumOperands() != results.size())
69*61352a58SStella Laurenzo     return emitOpError("has ")
70*61352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
71*61352a58SStella Laurenzo            << function.getName() << ") outputs " << results.size();
72*61352a58SStella Laurenzo 
73*61352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
74*61352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
75*61352a58SStella Laurenzo       return emitError() << "type of output operand " << i << " ("
76*61352a58SStella Laurenzo                          << getOperand(i).getType()
77*61352a58SStella Laurenzo                          << ") doesn't match function result type ("
78*61352a58SStella Laurenzo                          << results[i] << ")"
79*61352a58SStella Laurenzo                          << " in function @" << function.getName();
80*61352a58SStella Laurenzo 
81*61352a58SStella Laurenzo   return success();
82*61352a58SStella Laurenzo }
83*61352a58SStella Laurenzo 
84*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
85*61352a58SStella Laurenzo // ReturnOp
86*61352a58SStella Laurenzo //===----------------------------------------------------------------------===//
87*61352a58SStella Laurenzo 
88*61352a58SStella Laurenzo LogicalResult ReturnOp::verify() {
89*61352a58SStella Laurenzo   auto function = cast<FuncOp>((*this)->getParentOp());
90*61352a58SStella Laurenzo 
91*61352a58SStella Laurenzo   // The operand number and types must match the function signature.
92*61352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
93*61352a58SStella Laurenzo   if (getNumOperands() != results.size())
94*61352a58SStella Laurenzo     return emitOpError("has ")
95*61352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
96*61352a58SStella Laurenzo            << function.getName() << ") returns " << results.size();
97*61352a58SStella Laurenzo 
98*61352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
99*61352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
100*61352a58SStella Laurenzo       return emitError() << "type of return operand " << i << " ("
101*61352a58SStella Laurenzo                          << getOperand(i).getType()
102*61352a58SStella Laurenzo                          << ") doesn't match function result type ("
103*61352a58SStella Laurenzo                          << results[i] << ")"
104*61352a58SStella Laurenzo                          << " in function @" << function.getName();
105*61352a58SStella Laurenzo 
106*61352a58SStella Laurenzo   return success();
107*61352a58SStella Laurenzo }
108