161352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
261352a58SStella Laurenzo //
361352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
461352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
561352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
661352a58SStella Laurenzo //
761352a58SStella Laurenzo //===----------------------------------------------------------------------===//
861352a58SStella Laurenzo 
961352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
1061352a58SStella Laurenzo #include "mlir/IR/Builders.h"
1161352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h"
1261352a58SStella Laurenzo 
1361352a58SStella Laurenzo using namespace mlir;
1461352a58SStella Laurenzo using namespace mlir::ml_program;
1561352a58SStella Laurenzo 
1661352a58SStella Laurenzo //===----------------------------------------------------------------------===//
172bb25285SStella Laurenzo // Custom asm helpers
182bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
192bb25285SStella Laurenzo 
20*3bb79993SStella Laurenzo /// Parse and print an ordering clause for a variadic of consuming tokens
21*3bb79993SStella Laurenzo /// and an optional producing token.
22*3bb79993SStella Laurenzo ///
23*3bb79993SStella Laurenzo /// Syntax:
24*3bb79993SStella Laurenzo ///   ordering(%0, %1 -> !ml_program.token)
25*3bb79993SStella Laurenzo ///   ordering(() -> !ml_program.token)
26*3bb79993SStella Laurenzo ///   ordering(%0, %1)
27*3bb79993SStella Laurenzo ///
28*3bb79993SStella Laurenzo /// If both the consuming and producing token are not present on the op, then
29*3bb79993SStella Laurenzo /// the clause prints nothing.
30*3bb79993SStella Laurenzo static ParseResult parseTokenOrdering(
31*3bb79993SStella Laurenzo     OpAsmParser &parser,
32*3bb79993SStella Laurenzo     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
33*3bb79993SStella Laurenzo     Type &produceTokenType) {
34*3bb79993SStella Laurenzo   if (failed(parser.parseOptionalKeyword("ordering")) ||
35*3bb79993SStella Laurenzo       failed(parser.parseLParen()))
36*3bb79993SStella Laurenzo     return success();
37*3bb79993SStella Laurenzo 
38*3bb79993SStella Laurenzo   // Parse consuming token list. If there are no consuming tokens, the
39*3bb79993SStella Laurenzo   // '()' null list represents this.
40*3bb79993SStella Laurenzo   if (succeeded(parser.parseOptionalLParen())) {
41*3bb79993SStella Laurenzo     if (failed(parser.parseRParen()))
42*3bb79993SStella Laurenzo       return failure();
43*3bb79993SStella Laurenzo   } else {
44*3bb79993SStella Laurenzo     if (failed(parser.parseOperandList(consumeTokens,
45*3bb79993SStella Laurenzo                                        /*requiredOperandCount=*/-1)))
46*3bb79993SStella Laurenzo       return failure();
47*3bb79993SStella Laurenzo   }
48*3bb79993SStella Laurenzo 
49*3bb79993SStella Laurenzo   // Parse optional producer token.
50*3bb79993SStella Laurenzo   if (succeeded(parser.parseOptionalArrow()))
51*3bb79993SStella Laurenzo     if (failed(parser.parseType(produceTokenType)))
52*3bb79993SStella Laurenzo       return failure();
53*3bb79993SStella Laurenzo 
54*3bb79993SStella Laurenzo   if (failed(parser.parseRParen()))
55*3bb79993SStella Laurenzo     return failure();
56*3bb79993SStella Laurenzo 
57*3bb79993SStella Laurenzo   return success();
58*3bb79993SStella Laurenzo }
59*3bb79993SStella Laurenzo 
60*3bb79993SStella Laurenzo static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
61*3bb79993SStella Laurenzo                                OperandRange consumeTokens,
62*3bb79993SStella Laurenzo                                Type produceTokenType) {
63*3bb79993SStella Laurenzo   if (consumeTokens.empty() && !produceTokenType)
64*3bb79993SStella Laurenzo     return;
65*3bb79993SStella Laurenzo 
66*3bb79993SStella Laurenzo   p << " ordering(";
67*3bb79993SStella Laurenzo   if (consumeTokens.empty())
68*3bb79993SStella Laurenzo     p << "()";
69*3bb79993SStella Laurenzo   else
70*3bb79993SStella Laurenzo     p.printOperands(consumeTokens);
71*3bb79993SStella Laurenzo   if (produceTokenType) {
72*3bb79993SStella Laurenzo     p << " -> ";
73*3bb79993SStella Laurenzo     p.printType(produceTokenType);
74*3bb79993SStella Laurenzo   }
75*3bb79993SStella Laurenzo   p << ")";
76*3bb79993SStella Laurenzo }
77*3bb79993SStella Laurenzo 
782bb25285SStella Laurenzo /// some.op custom<TypeOrAttr>($type, $attr)
792bb25285SStella Laurenzo ///
802bb25285SStella Laurenzo /// Uninitialized:
812bb25285SStella Laurenzo ///   some.op : tensor<3xi32>
822bb25285SStella Laurenzo /// Initialized to narrower type than op:
832bb25285SStella Laurenzo ///   some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
842bb25285SStella Laurenzo static ParseResult parseTypedInitialValue(OpAsmParser &parser,
852bb25285SStella Laurenzo                                           TypeAttr &typeAttr, Attribute &attr) {
862bb25285SStella Laurenzo   if (succeeded(parser.parseOptionalLParen())) {
872bb25285SStella Laurenzo     if (failed(parser.parseAttribute(attr)))
882bb25285SStella Laurenzo       return failure();
892bb25285SStella Laurenzo     if (failed(parser.parseRParen()))
902bb25285SStella Laurenzo       return failure();
912bb25285SStella Laurenzo   }
922bb25285SStella Laurenzo 
932bb25285SStella Laurenzo   Type type;
942bb25285SStella Laurenzo   if (failed(parser.parseColonType(type)))
952bb25285SStella Laurenzo     return failure();
962bb25285SStella Laurenzo   typeAttr = TypeAttr::get(type);
972bb25285SStella Laurenzo   return success();
982bb25285SStella Laurenzo }
992bb25285SStella Laurenzo 
1002bb25285SStella Laurenzo static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
1012bb25285SStella Laurenzo                                    TypeAttr type, Attribute attr) {
1022bb25285SStella Laurenzo   if (attr) {
1032bb25285SStella Laurenzo     p << "(";
1042bb25285SStella Laurenzo     p.printAttribute(attr);
1052bb25285SStella Laurenzo     p << ")";
1062bb25285SStella Laurenzo   }
1072bb25285SStella Laurenzo 
1082bb25285SStella Laurenzo   p << " : ";
1092bb25285SStella Laurenzo   p.printAttribute(type);
1102bb25285SStella Laurenzo }
1112bb25285SStella Laurenzo 
1122bb25285SStella Laurenzo /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
1132bb25285SStella Laurenzo /// ->
1142bb25285SStella Laurenzo /// some.op public @foo
1152bb25285SStella Laurenzo /// some.op private @foo
1162bb25285SStella Laurenzo static ParseResult parseSymbolVisibility(OpAsmParser &parser,
1172bb25285SStella Laurenzo                                          StringAttr &symVisibilityAttr) {
1182bb25285SStella Laurenzo   StringRef symVisibility;
1192bb25285SStella Laurenzo   (void)parser.parseOptionalKeyword(&symVisibility,
1202bb25285SStella Laurenzo                                     {"public", "private", "nested"});
1212bb25285SStella Laurenzo   if (symVisibility.empty())
1222bb25285SStella Laurenzo     return parser.emitError(parser.getCurrentLocation())
1232bb25285SStella Laurenzo            << "expected 'public', 'private', or 'nested'";
1242bb25285SStella Laurenzo   if (!symVisibility.empty())
1252bb25285SStella Laurenzo     symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
1262bb25285SStella Laurenzo   return success();
1272bb25285SStella Laurenzo }
1282bb25285SStella Laurenzo 
1292bb25285SStella Laurenzo static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
1302bb25285SStella Laurenzo                                   StringAttr symVisibilityAttr) {
1312bb25285SStella Laurenzo   if (!symVisibilityAttr)
1322bb25285SStella Laurenzo     p << "public";
1332bb25285SStella Laurenzo   else
1342bb25285SStella Laurenzo     p << symVisibilityAttr.getValue();
1352bb25285SStella Laurenzo }
1362bb25285SStella Laurenzo 
1372bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
13861352a58SStella Laurenzo // TableGen'd op method definitions
13961352a58SStella Laurenzo //===----------------------------------------------------------------------===//
14061352a58SStella Laurenzo 
14161352a58SStella Laurenzo #define GET_OP_CLASSES
14261352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
14361352a58SStella Laurenzo 
14461352a58SStella Laurenzo //===----------------------------------------------------------------------===//
14561352a58SStella Laurenzo // FuncOp
14661352a58SStella Laurenzo //===----------------------------------------------------------------------===//
14761352a58SStella Laurenzo 
14861352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
14961352a58SStella Laurenzo   auto buildFuncType =
15061352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
15161352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
15261352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
15361352a58SStella Laurenzo 
15461352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
15561352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
15661352a58SStella Laurenzo }
15761352a58SStella Laurenzo 
15861352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) {
15961352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
16061352a58SStella Laurenzo }
16161352a58SStella Laurenzo 
16261352a58SStella Laurenzo //===----------------------------------------------------------------------===//
1632bb25285SStella Laurenzo // GlobalOp
1642bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
1652bb25285SStella Laurenzo 
1662bb25285SStella Laurenzo LogicalResult GlobalOp::verify() {
1672bb25285SStella Laurenzo   if (!getIsMutable() && !getValue())
1682bb25285SStella Laurenzo     return emitOpError() << "immutable global must have an initial value";
1692bb25285SStella Laurenzo   return success();
1702bb25285SStella Laurenzo }
1712bb25285SStella Laurenzo 
1722bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
173*3bb79993SStella Laurenzo // GlobalLoadOp
174*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===//
175*3bb79993SStella Laurenzo 
176*3bb79993SStella Laurenzo GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
177*3bb79993SStella Laurenzo   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
178*3bb79993SStella Laurenzo       getOperation()->getParentOp(), getGlobalAttr());
179*3bb79993SStella Laurenzo }
180*3bb79993SStella Laurenzo 
181*3bb79993SStella Laurenzo LogicalResult
182*3bb79993SStella Laurenzo GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
183*3bb79993SStella Laurenzo   GlobalOp referrent = getGlobalOp(symbolTable);
184*3bb79993SStella Laurenzo   if (!referrent)
185*3bb79993SStella Laurenzo     return emitOpError() << "undefined global: " << getGlobal();
186*3bb79993SStella Laurenzo 
187*3bb79993SStella Laurenzo   if (referrent.getType() != getResult().getType()) {
188*3bb79993SStella Laurenzo     return emitOpError() << "cannot load from global typed "
189*3bb79993SStella Laurenzo                          << referrent.getType() << " as "
190*3bb79993SStella Laurenzo                          << getResult().getType();
191*3bb79993SStella Laurenzo   }
192*3bb79993SStella Laurenzo 
193*3bb79993SStella Laurenzo   return success();
194*3bb79993SStella Laurenzo }
195*3bb79993SStella Laurenzo 
196*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===//
1972bb25285SStella Laurenzo // GlobalLoadConstOp
1982bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
1992bb25285SStella Laurenzo 
2002bb25285SStella Laurenzo GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
2012bb25285SStella Laurenzo   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
2022bb25285SStella Laurenzo       getOperation()->getParentOp(), getGlobalAttr());
2032bb25285SStella Laurenzo }
2042bb25285SStella Laurenzo 
2052bb25285SStella Laurenzo LogicalResult
2062bb25285SStella Laurenzo GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2072bb25285SStella Laurenzo   GlobalOp referrent = getGlobalOp(symbolTable);
2082bb25285SStella Laurenzo   if (!referrent)
2092bb25285SStella Laurenzo     return emitOpError() << "undefined global: " << getGlobal();
2102bb25285SStella Laurenzo 
2112bb25285SStella Laurenzo   if (referrent.getIsMutable())
2122bb25285SStella Laurenzo     return emitOpError() << "cannot load as const from mutable global "
2132bb25285SStella Laurenzo                          << getGlobal();
2142bb25285SStella Laurenzo 
2152bb25285SStella Laurenzo   if (referrent.getType() != getResult().getType())
2162bb25285SStella Laurenzo     return emitOpError() << "cannot load from global typed "
2172bb25285SStella Laurenzo                          << referrent.getType() << " as "
2182bb25285SStella Laurenzo                          << getResult().getType();
2192bb25285SStella Laurenzo 
2202bb25285SStella Laurenzo   return success();
2212bb25285SStella Laurenzo }
2222bb25285SStella Laurenzo 
2232bb25285SStella Laurenzo //===----------------------------------------------------------------------===//
224*3bb79993SStella Laurenzo // GlobalStoreOp
225*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===//
226*3bb79993SStella Laurenzo 
227*3bb79993SStella Laurenzo GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
228*3bb79993SStella Laurenzo   return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
229*3bb79993SStella Laurenzo       getOperation()->getParentOp(), getGlobalAttr());
230*3bb79993SStella Laurenzo }
231*3bb79993SStella Laurenzo 
232*3bb79993SStella Laurenzo LogicalResult
233*3bb79993SStella Laurenzo GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
234*3bb79993SStella Laurenzo   GlobalOp referrent = getGlobalOp(symbolTable);
235*3bb79993SStella Laurenzo   if (!referrent)
236*3bb79993SStella Laurenzo     return emitOpError() << "undefined global: " << getGlobal();
237*3bb79993SStella Laurenzo 
238*3bb79993SStella Laurenzo   if (!referrent.getIsMutable()) {
239*3bb79993SStella Laurenzo     return emitOpError() << "cannot store to an immutable global "
240*3bb79993SStella Laurenzo                          << getGlobal();
241*3bb79993SStella Laurenzo   }
242*3bb79993SStella Laurenzo 
243*3bb79993SStella Laurenzo   if (referrent.getType() != getValue().getType()) {
244*3bb79993SStella Laurenzo     return emitOpError() << "cannot store to a global typed "
245*3bb79993SStella Laurenzo                          << referrent.getType() << " from "
246*3bb79993SStella Laurenzo                          << getValue().getType();
247*3bb79993SStella Laurenzo   }
248*3bb79993SStella Laurenzo 
249*3bb79993SStella Laurenzo   return success();
250*3bb79993SStella Laurenzo }
251*3bb79993SStella Laurenzo 
252*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===//
25361352a58SStella Laurenzo // SubgraphOp
25461352a58SStella Laurenzo //===----------------------------------------------------------------------===//
25561352a58SStella Laurenzo 
25661352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
25761352a58SStella Laurenzo   auto buildFuncType =
25861352a58SStella Laurenzo       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
25961352a58SStella Laurenzo          function_interface_impl::VariadicFlag,
26061352a58SStella Laurenzo          std::string &) { return builder.getFunctionType(argTypes, results); };
26161352a58SStella Laurenzo 
26261352a58SStella Laurenzo   return function_interface_impl::parseFunctionOp(
26361352a58SStella Laurenzo       parser, result, /*allowVariadic=*/false, buildFuncType);
26461352a58SStella Laurenzo }
26561352a58SStella Laurenzo 
26661352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) {
26761352a58SStella Laurenzo   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
26861352a58SStella Laurenzo }
26961352a58SStella Laurenzo 
27061352a58SStella Laurenzo //===----------------------------------------------------------------------===//
27161352a58SStella Laurenzo // OutputOp
27261352a58SStella Laurenzo //===----------------------------------------------------------------------===//
27361352a58SStella Laurenzo 
27461352a58SStella Laurenzo LogicalResult OutputOp::verify() {
27561352a58SStella Laurenzo   auto function = cast<SubgraphOp>((*this)->getParentOp());
27661352a58SStella Laurenzo 
27761352a58SStella Laurenzo   // The operand number and types must match the function signature.
27861352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
27961352a58SStella Laurenzo   if (getNumOperands() != results.size())
28061352a58SStella Laurenzo     return emitOpError("has ")
28161352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
28261352a58SStella Laurenzo            << function.getName() << ") outputs " << results.size();
28361352a58SStella Laurenzo 
28461352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
28561352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
28661352a58SStella Laurenzo       return emitError() << "type of output operand " << i << " ("
28761352a58SStella Laurenzo                          << getOperand(i).getType()
28861352a58SStella Laurenzo                          << ") doesn't match function result type ("
28961352a58SStella Laurenzo                          << results[i] << ")"
29061352a58SStella Laurenzo                          << " in function @" << function.getName();
29161352a58SStella Laurenzo 
29261352a58SStella Laurenzo   return success();
29361352a58SStella Laurenzo }
29461352a58SStella Laurenzo 
29561352a58SStella Laurenzo //===----------------------------------------------------------------------===//
29661352a58SStella Laurenzo // ReturnOp
29761352a58SStella Laurenzo //===----------------------------------------------------------------------===//
29861352a58SStella Laurenzo 
29961352a58SStella Laurenzo LogicalResult ReturnOp::verify() {
30061352a58SStella Laurenzo   auto function = cast<FuncOp>((*this)->getParentOp());
30161352a58SStella Laurenzo 
30261352a58SStella Laurenzo   // The operand number and types must match the function signature.
30361352a58SStella Laurenzo   const auto &results = function.getFunctionType().getResults();
30461352a58SStella Laurenzo   if (getNumOperands() != results.size())
30561352a58SStella Laurenzo     return emitOpError("has ")
30661352a58SStella Laurenzo            << getNumOperands() << " operands, but enclosing function (@"
30761352a58SStella Laurenzo            << function.getName() << ") returns " << results.size();
30861352a58SStella Laurenzo 
30961352a58SStella Laurenzo   for (unsigned i = 0, e = results.size(); i != e; ++i)
31061352a58SStella Laurenzo     if (getOperand(i).getType() != results[i])
31161352a58SStella Laurenzo       return emitError() << "type of return operand " << i << " ("
31261352a58SStella Laurenzo                          << getOperand(i).getType()
31361352a58SStella Laurenzo                          << ") doesn't match function result type ("
31461352a58SStella Laurenzo                          << results[i] << ")"
31561352a58SStella Laurenzo                          << " in function @" << function.getName();
31661352a58SStella Laurenzo 
31761352a58SStella Laurenzo   return success();
31861352a58SStella Laurenzo }
319