161352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
261352a58SStella Laurenzo //
361352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
461352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
561352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
661352a58SStella Laurenzo //
761352a58SStella Laurenzo //===----------------------------------------------------------------------===//
861352a58SStella Laurenzo 
961352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1061352a58SStella Laurenzo #include "mlir/IR/Builders.h"
1161352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h"
1261352a58SStella Laurenzo 
1361352a58SStella Laurenzo using namespace mlir;
1461352a58SStella Laurenzo using namespace mlir::ml_program;
1561352a58SStella Laurenzo 
1661352a58SStella Laurenzo //===----------------------------------------------------------------------===//
17*2bb25285SStella Laurenzo // Custom asm helpers
18*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
19*2bb25285SStella Laurenzo 
20*2bb25285SStella Laurenzo /// some.op custom<TypeOrAttr>($type, $attr)
21*2bb25285SStella Laurenzo ///
22*2bb25285SStella Laurenzo /// Uninitialized:
23*2bb25285SStella Laurenzo ///   some.op : tensor<3xi32>
24*2bb25285SStella Laurenzo /// Initialized to narrower type than op:
25*2bb25285SStella Laurenzo ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
26*2bb25285SStella Laurenzo static ParseResult parseTypedInitialValue(OpAsmParser &parser,
27*2bb25285SStella Laurenzo                                           TypeAttr &typeAttr, Attribute &attr) {
28*2bb25285SStella Laurenzo   if (succeeded(parser.parseOptionalLParen())) {
29*2bb25285SStella Laurenzo     if (failed(parser.parseAttribute(attr)))
30*2bb25285SStella Laurenzo       return failure();
31*2bb25285SStella Laurenzo     if (failed(parser.parseRParen()))
32*2bb25285SStella Laurenzo       return failure();
33*2bb25285SStella Laurenzo   }
34*2bb25285SStella Laurenzo 
35*2bb25285SStella Laurenzo   Type type;
36*2bb25285SStella Laurenzo   if (failed(parser.parseColonType(type)))
37*2bb25285SStella Laurenzo     return failure();
38*2bb25285SStella Laurenzo   typeAttr = TypeAttr::get(type);
39*2bb25285SStella Laurenzo   return success();
40*2bb25285SStella Laurenzo }
41*2bb25285SStella Laurenzo 
42*2bb25285SStella Laurenzo static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
43*2bb25285SStella Laurenzo                                    TypeAttr type, Attribute attr) {
44*2bb25285SStella Laurenzo   if (attr) {
45*2bb25285SStella Laurenzo     p << "(";
46*2bb25285SStella Laurenzo     p.printAttribute(attr);
47*2bb25285SStella Laurenzo     p << ")";
48*2bb25285SStella Laurenzo   }
49*2bb25285SStella Laurenzo 
50*2bb25285SStella Laurenzo   p << " : ";
51*2bb25285SStella Laurenzo   p.printAttribute(type);
52*2bb25285SStella Laurenzo }
53*2bb25285SStella Laurenzo 
54*2bb25285SStella Laurenzo /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
55*2bb25285SStella Laurenzo /// ->
56*2bb25285SStella Laurenzo /// some.op public @foo
57*2bb25285SStella Laurenzo /// some.op private @foo
58*2bb25285SStella Laurenzo static ParseResult parseSymbolVisibility(OpAsmParser &parser,
59*2bb25285SStella Laurenzo                                          StringAttr &symVisibilityAttr) {
60*2bb25285SStella Laurenzo   StringRef symVisibility;
61*2bb25285SStella Laurenzo   (void)parser.parseOptionalKeyword(&symVisibility,
62*2bb25285SStella Laurenzo                                     {"public", "private", "nested"});
63*2bb25285SStella Laurenzo   if (symVisibility.empty())
64*2bb25285SStella Laurenzo     return parser.emitError(parser.getCurrentLocation())
65*2bb25285SStella Laurenzo            << "expected 'public', 'private', or 'nested'";
66*2bb25285SStella Laurenzo   if (!symVisibility.empty())
67*2bb25285SStella Laurenzo     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
68*2bb25285SStella Laurenzo   return success();
69*2bb25285SStella Laurenzo }
70*2bb25285SStella Laurenzo 
71*2bb25285SStella Laurenzo static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
72*2bb25285SStella Laurenzo                                   StringAttr symVisibilityAttr) {
73*2bb25285SStella Laurenzo   if (!symVisibilityAttr)
74*2bb25285SStella Laurenzo     p << "public";
75*2bb25285SStella Laurenzo   else
76*2bb25285SStella Laurenzo     p << symVisibilityAttr.getValue();
77*2bb25285SStella Laurenzo }
78*2bb25285SStella Laurenzo 
79*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
8061352a58SStella Laurenzo // TableGen'd op method definitions
8161352a58SStella Laurenzo //===----------------------------------------------------------------------===//
8261352a58SStella Laurenzo 
8361352a58SStella Laurenzo #define GET_OP_CLASSES
8461352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
8561352a58SStella Laurenzo 
8661352a58SStella Laurenzo //===----------------------------------------------------------------------===//
8761352a58SStella Laurenzo // FuncOp
8861352a58SStella Laurenzo //===----------------------------------------------------------------------===//
8961352a58SStella Laurenzo 
9061352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
9161352a58SStella Laurenzo   auto buildFuncType =
9261352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
9361352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
9461352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
9561352a58SStella Laurenzo 
9661352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
9761352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
9861352a58SStella Laurenzo }
9961352a58SStella Laurenzo 
10061352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) {
10161352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
10261352a58SStella Laurenzo }
10361352a58SStella Laurenzo 
10461352a58SStella Laurenzo //===----------------------------------------------------------------------===//
105*2bb25285SStella Laurenzo // GlobalOp
106*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
107*2bb25285SStella Laurenzo 
108*2bb25285SStella Laurenzo LogicalResult GlobalOp::verify() {
109*2bb25285SStella Laurenzo   if (!getIsMutable() && !getValue())
110*2bb25285SStella Laurenzo     return emitOpError() << "immutable global must have an initial value";
111*2bb25285SStella Laurenzo   return success();
112*2bb25285SStella Laurenzo }
113*2bb25285SStella Laurenzo 
114*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
115*2bb25285SStella Laurenzo // GlobalLoadConstOp
116*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
117*2bb25285SStella Laurenzo 
118*2bb25285SStella Laurenzo GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
119*2bb25285SStella Laurenzo   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
120*2bb25285SStella Laurenzo       getOperation()->getParentOp(), getGlobalAttr());
121*2bb25285SStella Laurenzo }
122*2bb25285SStella Laurenzo 
123*2bb25285SStella Laurenzo LogicalResult
124*2bb25285SStella Laurenzo GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
125*2bb25285SStella Laurenzo   GlobalOp referrent = getGlobalOp(symbolTable);
126*2bb25285SStella Laurenzo   if (!referrent)
127*2bb25285SStella Laurenzo     return emitOpError() << "undefined global: " << getGlobal();
128*2bb25285SStella Laurenzo 
129*2bb25285SStella Laurenzo   if (referrent.getIsMutable())
130*2bb25285SStella Laurenzo     return emitOpError() << "cannot load as const from mutable global "
131*2bb25285SStella Laurenzo                          << getGlobal();
132*2bb25285SStella Laurenzo 
133*2bb25285SStella Laurenzo   if (referrent.getType() != getResult().getType())
134*2bb25285SStella Laurenzo     return emitOpError() << "cannot load from global typed "
135*2bb25285SStella Laurenzo                          << referrent.getType() << " as "
136*2bb25285SStella Laurenzo                          << getResult().getType();
137*2bb25285SStella Laurenzo 
138*2bb25285SStella Laurenzo   return success();
139*2bb25285SStella Laurenzo }
140*2bb25285SStella Laurenzo 
141*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
14261352a58SStella Laurenzo // SubgraphOp
14361352a58SStella Laurenzo //===----------------------------------------------------------------------===//
14461352a58SStella Laurenzo 
14561352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
14661352a58SStella Laurenzo   auto buildFuncType =
14761352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
14861352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
14961352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
15061352a58SStella Laurenzo 
15161352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
15261352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
15361352a58SStella Laurenzo }
15461352a58SStella Laurenzo 
15561352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) {
15661352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
15761352a58SStella Laurenzo }
15861352a58SStella Laurenzo 
15961352a58SStella Laurenzo //===----------------------------------------------------------------------===//
16061352a58SStella Laurenzo // OutputOp
16161352a58SStella Laurenzo //===----------------------------------------------------------------------===//
16261352a58SStella Laurenzo 
16361352a58SStella Laurenzo LogicalResult OutputOp::verify() {
16461352a58SStella Laurenzo   auto function = cast<SubgraphOp>((*this)->getParentOp());
16561352a58SStella Laurenzo 
16661352a58SStella Laurenzo   // The operand number and types must match the function signature.
16761352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
16861352a58SStella Laurenzo   if (getNumOperands() != results.size())
16961352a58SStella Laurenzo     return emitOpError("has ")
17061352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
17161352a58SStella Laurenzo            << function.getName() << ") outputs " << results.size();
17261352a58SStella Laurenzo 
17361352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
17461352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
17561352a58SStella Laurenzo       return emitError() << "type of output operand " << i << " ("
17661352a58SStella Laurenzo                          << getOperand(i).getType()
17761352a58SStella Laurenzo                          << ") doesn't match function result type ("
17861352a58SStella Laurenzo                          << results[i] << ")"
17961352a58SStella Laurenzo                          << " in function @" << function.getName();
18061352a58SStella Laurenzo 
18161352a58SStella Laurenzo   return success();
18261352a58SStella Laurenzo }
18361352a58SStella Laurenzo 
18461352a58SStella Laurenzo //===----------------------------------------------------------------------===//
18561352a58SStella Laurenzo // ReturnOp
18661352a58SStella Laurenzo //===----------------------------------------------------------------------===//
18761352a58SStella Laurenzo 
18861352a58SStella Laurenzo LogicalResult ReturnOp::verify() {
18961352a58SStella Laurenzo   auto function = cast<FuncOp>((*this)->getParentOp());
19061352a58SStella Laurenzo 
19161352a58SStella Laurenzo   // The operand number and types must match the function signature.
19261352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
19361352a58SStella Laurenzo   if (getNumOperands() != results.size())
19461352a58SStella Laurenzo     return emitOpError("has ")
19561352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
19661352a58SStella Laurenzo            << function.getName() << ") returns " << results.size();
19761352a58SStella Laurenzo 
19861352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
19961352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
20061352a58SStella Laurenzo       return emitError() << "type of return operand " << i << " ("
20161352a58SStella Laurenzo                          << getOperand(i).getType()
20261352a58SStella Laurenzo                          << ") doesn't match function result type ("
20361352a58SStella Laurenzo                          << results[i] << ")"
20461352a58SStella Laurenzo                          << " in function @" << function.getName();
20561352a58SStella Laurenzo 
20661352a58SStella Laurenzo   return success();
20761352a58SStella Laurenzo }
208