1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/FunctionImplementation.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // TableGen'd op method definitions
18 //===----------------------------------------------------------------------===//
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // FuncOp
25 //===----------------------------------------------------------------------===//
26 
27 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
28   auto buildFuncType =
29       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
30          function_interface_impl::VariadicFlag,
31          std::string &) { return builder.getFunctionType(argTypes, results); };
32 
33   return function_interface_impl::parseFunctionOp(
34       parser, result, /*allowVariadic=*/false, buildFuncType);
35 }
36 
37 void FuncOp::print(OpAsmPrinter &p) {
38   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
39 }
40 
41 //===----------------------------------------------------------------------===//
42 // SubgraphOp
43 //===----------------------------------------------------------------------===//
44 
45 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
46   auto buildFuncType =
47       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
48          function_interface_impl::VariadicFlag,
49          std::string &) { return builder.getFunctionType(argTypes, results); };
50 
51   return function_interface_impl::parseFunctionOp(
52       parser, result, /*allowVariadic=*/false, buildFuncType);
53 }
54 
55 void SubgraphOp::print(OpAsmPrinter &p) {
56   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // OutputOp
61 //===----------------------------------------------------------------------===//
62 
63 LogicalResult OutputOp::verify() {
64   auto function = cast<SubgraphOp>((*this)->getParentOp());
65 
66   // The operand number and types must match the function signature.
67   const auto &results = function.getFunctionType().getResults();
68   if (getNumOperands() != results.size())
69     return emitOpError("has ")
70            << getNumOperands() << " operands, but enclosing function (@"
71            << function.getName() << ") outputs " << results.size();
72 
73   for (unsigned i = 0, e = results.size(); i != e; ++i)
74     if (getOperand(i).getType() != results[i])
75       return emitError() << "type of output operand " << i << " ("
76                          << getOperand(i).getType()
77                          << ") doesn't match function result type ("
78                          << results[i] << ")"
79                          << " in function @" << function.getName();
80 
81   return success();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // ReturnOp
86 //===----------------------------------------------------------------------===//
87 
88 LogicalResult ReturnOp::verify() {
89   auto function = cast<FuncOp>((*this)->getParentOp());
90 
91   // The operand number and types must match the function signature.
92   const auto &results = function.getFunctionType().getResults();
93   if (getNumOperands() != results.size())
94     return emitOpError("has ")
95            << getNumOperands() << " operands, but enclosing function (@"
96            << function.getName() << ") returns " << results.size();
97 
98   for (unsigned i = 0, e = results.size(); i != e; ++i)
99     if (getOperand(i).getType() != results[i])
100       return emitError() << "type of return operand " << i << " ("
101                          << getOperand(i).getType()
102                          << ") doesn't match function result type ("
103                          << results[i] << ")"
104                          << " in function @" << function.getName();
105 
106   return success();
107 }
108