161352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 261352a58SStella Laurenzo // 361352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 461352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 561352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 661352a58SStella Laurenzo // 761352a58SStella Laurenzo //===----------------------------------------------------------------------===// 861352a58SStella Laurenzo 961352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 1061352a58SStella Laurenzo #include "mlir/IR/Builders.h" 1161352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h" 1261352a58SStella Laurenzo 1361352a58SStella Laurenzo using namespace mlir; 1461352a58SStella Laurenzo using namespace mlir::ml_program; 1561352a58SStella Laurenzo 1661352a58SStella Laurenzo //===----------------------------------------------------------------------===// 172bb25285SStella Laurenzo // Custom asm helpers 182bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 192bb25285SStella Laurenzo 20*3bb79993SStella Laurenzo /// Parse and print an ordering clause for a variadic of consuming tokens 21*3bb79993SStella Laurenzo /// and an optional producing token. 22*3bb79993SStella Laurenzo /// 23*3bb79993SStella Laurenzo /// Syntax: 24*3bb79993SStella Laurenzo /// ordering(%0, %1 -> !ml_program.token) 25*3bb79993SStella Laurenzo /// ordering(() -> !ml_program.token) 26*3bb79993SStella Laurenzo /// ordering(%0, %1) 27*3bb79993SStella Laurenzo /// 28*3bb79993SStella Laurenzo /// If both the consuming and producing token are not present on the op, then 29*3bb79993SStella Laurenzo /// the clause prints nothing. 30*3bb79993SStella Laurenzo static ParseResult parseTokenOrdering( 31*3bb79993SStella Laurenzo OpAsmParser &parser, 32*3bb79993SStella Laurenzo SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, 33*3bb79993SStella Laurenzo Type &produceTokenType) { 34*3bb79993SStella Laurenzo if (failed(parser.parseOptionalKeyword("ordering")) || 35*3bb79993SStella Laurenzo failed(parser.parseLParen())) 36*3bb79993SStella Laurenzo return success(); 37*3bb79993SStella Laurenzo 38*3bb79993SStella Laurenzo // Parse consuming token list. If there are no consuming tokens, the 39*3bb79993SStella Laurenzo // '()' null list represents this. 40*3bb79993SStella Laurenzo if (succeeded(parser.parseOptionalLParen())) { 41*3bb79993SStella Laurenzo if (failed(parser.parseRParen())) 42*3bb79993SStella Laurenzo return failure(); 43*3bb79993SStella Laurenzo } else { 44*3bb79993SStella Laurenzo if (failed(parser.parseOperandList(consumeTokens, 45*3bb79993SStella Laurenzo /*requiredOperandCount=*/-1))) 46*3bb79993SStella Laurenzo return failure(); 47*3bb79993SStella Laurenzo } 48*3bb79993SStella Laurenzo 49*3bb79993SStella Laurenzo // Parse optional producer token. 50*3bb79993SStella Laurenzo if (succeeded(parser.parseOptionalArrow())) 51*3bb79993SStella Laurenzo if (failed(parser.parseType(produceTokenType))) 52*3bb79993SStella Laurenzo return failure(); 53*3bb79993SStella Laurenzo 54*3bb79993SStella Laurenzo if (failed(parser.parseRParen())) 55*3bb79993SStella Laurenzo return failure(); 56*3bb79993SStella Laurenzo 57*3bb79993SStella Laurenzo return success(); 58*3bb79993SStella Laurenzo } 59*3bb79993SStella Laurenzo 60*3bb79993SStella Laurenzo static void printTokenOrdering(OpAsmPrinter &p, Operation *op, 61*3bb79993SStella Laurenzo OperandRange consumeTokens, 62*3bb79993SStella Laurenzo Type produceTokenType) { 63*3bb79993SStella Laurenzo if (consumeTokens.empty() && !produceTokenType) 64*3bb79993SStella Laurenzo return; 65*3bb79993SStella Laurenzo 66*3bb79993SStella Laurenzo p << " ordering("; 67*3bb79993SStella Laurenzo if (consumeTokens.empty()) 68*3bb79993SStella Laurenzo p << "()"; 69*3bb79993SStella Laurenzo else 70*3bb79993SStella Laurenzo p.printOperands(consumeTokens); 71*3bb79993SStella Laurenzo if (produceTokenType) { 72*3bb79993SStella Laurenzo p << " -> "; 73*3bb79993SStella Laurenzo p.printType(produceTokenType); 74*3bb79993SStella Laurenzo } 75*3bb79993SStella Laurenzo p << ")"; 76*3bb79993SStella Laurenzo } 77*3bb79993SStella Laurenzo 782bb25285SStella Laurenzo /// some.op custom<TypeOrAttr>($type, $attr) 792bb25285SStella Laurenzo /// 802bb25285SStella Laurenzo /// Uninitialized: 812bb25285SStella Laurenzo /// some.op : tensor<3xi32> 822bb25285SStella Laurenzo /// Initialized to narrower type than op: 832bb25285SStella Laurenzo /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 842bb25285SStella Laurenzo static ParseResult parseTypedInitialValue(OpAsmParser &parser, 852bb25285SStella Laurenzo TypeAttr &typeAttr, Attribute &attr) { 862bb25285SStella Laurenzo if (succeeded(parser.parseOptionalLParen())) { 872bb25285SStella Laurenzo if (failed(parser.parseAttribute(attr))) 882bb25285SStella Laurenzo return failure(); 892bb25285SStella Laurenzo if (failed(parser.parseRParen())) 902bb25285SStella Laurenzo return failure(); 912bb25285SStella Laurenzo } 922bb25285SStella Laurenzo 932bb25285SStella Laurenzo Type type; 942bb25285SStella Laurenzo if (failed(parser.parseColonType(type))) 952bb25285SStella Laurenzo return failure(); 962bb25285SStella Laurenzo typeAttr = TypeAttr::get(type); 972bb25285SStella Laurenzo return success(); 982bb25285SStella Laurenzo } 992bb25285SStella Laurenzo 1002bb25285SStella Laurenzo static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 1012bb25285SStella Laurenzo TypeAttr type, Attribute attr) { 1022bb25285SStella Laurenzo if (attr) { 1032bb25285SStella Laurenzo p << "("; 1042bb25285SStella Laurenzo p.printAttribute(attr); 1052bb25285SStella Laurenzo p << ")"; 1062bb25285SStella Laurenzo } 1072bb25285SStella Laurenzo 1082bb25285SStella Laurenzo p << " : "; 1092bb25285SStella Laurenzo p.printAttribute(type); 1102bb25285SStella Laurenzo } 1112bb25285SStella Laurenzo 1122bb25285SStella Laurenzo /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 1132bb25285SStella Laurenzo /// -> 1142bb25285SStella Laurenzo /// some.op public @foo 1152bb25285SStella Laurenzo /// some.op private @foo 1162bb25285SStella Laurenzo static ParseResult parseSymbolVisibility(OpAsmParser &parser, 1172bb25285SStella Laurenzo StringAttr &symVisibilityAttr) { 1182bb25285SStella Laurenzo StringRef symVisibility; 1192bb25285SStella Laurenzo (void)parser.parseOptionalKeyword(&symVisibility, 1202bb25285SStella Laurenzo {"public", "private", "nested"}); 1212bb25285SStella Laurenzo if (symVisibility.empty()) 1222bb25285SStella Laurenzo return parser.emitError(parser.getCurrentLocation()) 1232bb25285SStella Laurenzo << "expected 'public', 'private', or 'nested'"; 1242bb25285SStella Laurenzo if (!symVisibility.empty()) 1252bb25285SStella Laurenzo symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 1262bb25285SStella Laurenzo return success(); 1272bb25285SStella Laurenzo } 1282bb25285SStella Laurenzo 1292bb25285SStella Laurenzo static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 1302bb25285SStella Laurenzo StringAttr symVisibilityAttr) { 1312bb25285SStella Laurenzo if (!symVisibilityAttr) 1322bb25285SStella Laurenzo p << "public"; 1332bb25285SStella Laurenzo else 1342bb25285SStella Laurenzo p << symVisibilityAttr.getValue(); 1352bb25285SStella Laurenzo } 1362bb25285SStella Laurenzo 1372bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 13861352a58SStella Laurenzo // TableGen'd op method definitions 13961352a58SStella Laurenzo //===----------------------------------------------------------------------===// 14061352a58SStella Laurenzo 14161352a58SStella Laurenzo #define GET_OP_CLASSES 14261352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 14361352a58SStella Laurenzo 14461352a58SStella Laurenzo //===----------------------------------------------------------------------===// 14561352a58SStella Laurenzo // FuncOp 14661352a58SStella Laurenzo //===----------------------------------------------------------------------===// 14761352a58SStella Laurenzo 14861352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 14961352a58SStella Laurenzo auto buildFuncType = 15061352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 15161352a58SStella Laurenzo function_interface_impl::VariadicFlag, 15261352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 15361352a58SStella Laurenzo 15461352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 15561352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 15661352a58SStella Laurenzo } 15761352a58SStella Laurenzo 15861352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) { 15961352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 16061352a58SStella Laurenzo } 16161352a58SStella Laurenzo 16261352a58SStella Laurenzo //===----------------------------------------------------------------------===// 1632bb25285SStella Laurenzo // GlobalOp 1642bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 1652bb25285SStella Laurenzo 1662bb25285SStella Laurenzo LogicalResult GlobalOp::verify() { 1672bb25285SStella Laurenzo if (!getIsMutable() && !getValue()) 1682bb25285SStella Laurenzo return emitOpError() << "immutable global must have an initial value"; 1692bb25285SStella Laurenzo return success(); 1702bb25285SStella Laurenzo } 1712bb25285SStella Laurenzo 1722bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 173*3bb79993SStella Laurenzo // GlobalLoadOp 174*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===// 175*3bb79993SStella Laurenzo 176*3bb79993SStella Laurenzo GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { 177*3bb79993SStella Laurenzo return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 178*3bb79993SStella Laurenzo getOperation()->getParentOp(), getGlobalAttr()); 179*3bb79993SStella Laurenzo } 180*3bb79993SStella Laurenzo 181*3bb79993SStella Laurenzo LogicalResult 182*3bb79993SStella Laurenzo GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 183*3bb79993SStella Laurenzo GlobalOp referrent = getGlobalOp(symbolTable); 184*3bb79993SStella Laurenzo if (!referrent) 185*3bb79993SStella Laurenzo return emitOpError() << "undefined global: " << getGlobal(); 186*3bb79993SStella Laurenzo 187*3bb79993SStella Laurenzo if (referrent.getType() != getResult().getType()) { 188*3bb79993SStella Laurenzo return emitOpError() << "cannot load from global typed " 189*3bb79993SStella Laurenzo << referrent.getType() << " as " 190*3bb79993SStella Laurenzo << getResult().getType(); 191*3bb79993SStella Laurenzo } 192*3bb79993SStella Laurenzo 193*3bb79993SStella Laurenzo return success(); 194*3bb79993SStella Laurenzo } 195*3bb79993SStella Laurenzo 196*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===// 1972bb25285SStella Laurenzo // GlobalLoadConstOp 1982bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 1992bb25285SStella Laurenzo 2002bb25285SStella Laurenzo GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 2012bb25285SStella Laurenzo return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 2022bb25285SStella Laurenzo getOperation()->getParentOp(), getGlobalAttr()); 2032bb25285SStella Laurenzo } 2042bb25285SStella Laurenzo 2052bb25285SStella Laurenzo LogicalResult 2062bb25285SStella Laurenzo GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 2072bb25285SStella Laurenzo GlobalOp referrent = getGlobalOp(symbolTable); 2082bb25285SStella Laurenzo if (!referrent) 2092bb25285SStella Laurenzo return emitOpError() << "undefined global: " << getGlobal(); 2102bb25285SStella Laurenzo 2112bb25285SStella Laurenzo if (referrent.getIsMutable()) 2122bb25285SStella Laurenzo return emitOpError() << "cannot load as const from mutable global " 2132bb25285SStella Laurenzo << getGlobal(); 2142bb25285SStella Laurenzo 2152bb25285SStella Laurenzo if (referrent.getType() != getResult().getType()) 2162bb25285SStella Laurenzo return emitOpError() << "cannot load from global typed " 2172bb25285SStella Laurenzo << referrent.getType() << " as " 2182bb25285SStella Laurenzo << getResult().getType(); 2192bb25285SStella Laurenzo 2202bb25285SStella Laurenzo return success(); 2212bb25285SStella Laurenzo } 2222bb25285SStella Laurenzo 2232bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 224*3bb79993SStella Laurenzo // GlobalStoreOp 225*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===// 226*3bb79993SStella Laurenzo 227*3bb79993SStella Laurenzo GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { 228*3bb79993SStella Laurenzo return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 229*3bb79993SStella Laurenzo getOperation()->getParentOp(), getGlobalAttr()); 230*3bb79993SStella Laurenzo } 231*3bb79993SStella Laurenzo 232*3bb79993SStella Laurenzo LogicalResult 233*3bb79993SStella Laurenzo GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 234*3bb79993SStella Laurenzo GlobalOp referrent = getGlobalOp(symbolTable); 235*3bb79993SStella Laurenzo if (!referrent) 236*3bb79993SStella Laurenzo return emitOpError() << "undefined global: " << getGlobal(); 237*3bb79993SStella Laurenzo 238*3bb79993SStella Laurenzo if (!referrent.getIsMutable()) { 239*3bb79993SStella Laurenzo return emitOpError() << "cannot store to an immutable global " 240*3bb79993SStella Laurenzo << getGlobal(); 241*3bb79993SStella Laurenzo } 242*3bb79993SStella Laurenzo 243*3bb79993SStella Laurenzo if (referrent.getType() != getValue().getType()) { 244*3bb79993SStella Laurenzo return emitOpError() << "cannot store to a global typed " 245*3bb79993SStella Laurenzo << referrent.getType() << " from " 246*3bb79993SStella Laurenzo << getValue().getType(); 247*3bb79993SStella Laurenzo } 248*3bb79993SStella Laurenzo 249*3bb79993SStella Laurenzo return success(); 250*3bb79993SStella Laurenzo } 251*3bb79993SStella Laurenzo 252*3bb79993SStella Laurenzo //===----------------------------------------------------------------------===// 25361352a58SStella Laurenzo // SubgraphOp 25461352a58SStella Laurenzo //===----------------------------------------------------------------------===// 25561352a58SStella Laurenzo 25661352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 25761352a58SStella Laurenzo auto buildFuncType = 25861352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 25961352a58SStella Laurenzo function_interface_impl::VariadicFlag, 26061352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 26161352a58SStella Laurenzo 26261352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 26361352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 26461352a58SStella Laurenzo } 26561352a58SStella Laurenzo 26661352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) { 26761352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 26861352a58SStella Laurenzo } 26961352a58SStella Laurenzo 27061352a58SStella Laurenzo //===----------------------------------------------------------------------===// 27161352a58SStella Laurenzo // OutputOp 27261352a58SStella Laurenzo //===----------------------------------------------------------------------===// 27361352a58SStella Laurenzo 27461352a58SStella Laurenzo LogicalResult OutputOp::verify() { 27561352a58SStella Laurenzo auto function = cast<SubgraphOp>((*this)->getParentOp()); 27661352a58SStella Laurenzo 27761352a58SStella Laurenzo // The operand number and types must match the function signature. 27861352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 27961352a58SStella Laurenzo if (getNumOperands() != results.size()) 28061352a58SStella Laurenzo return emitOpError("has ") 28161352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 28261352a58SStella Laurenzo << function.getName() << ") outputs " << results.size(); 28361352a58SStella Laurenzo 28461352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 28561352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 28661352a58SStella Laurenzo return emitError() << "type of output operand " << i << " (" 28761352a58SStella Laurenzo << getOperand(i).getType() 28861352a58SStella Laurenzo << ") doesn't match function result type (" 28961352a58SStella Laurenzo << results[i] << ")" 29061352a58SStella Laurenzo << " in function @" << function.getName(); 29161352a58SStella Laurenzo 29261352a58SStella Laurenzo return success(); 29361352a58SStella Laurenzo } 29461352a58SStella Laurenzo 29561352a58SStella Laurenzo //===----------------------------------------------------------------------===// 29661352a58SStella Laurenzo // ReturnOp 29761352a58SStella Laurenzo //===----------------------------------------------------------------------===// 29861352a58SStella Laurenzo 29961352a58SStella Laurenzo LogicalResult ReturnOp::verify() { 30061352a58SStella Laurenzo auto function = cast<FuncOp>((*this)->getParentOp()); 30161352a58SStella Laurenzo 30261352a58SStella Laurenzo // The operand number and types must match the function signature. 30361352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 30461352a58SStella Laurenzo if (getNumOperands() != results.size()) 30561352a58SStella Laurenzo return emitOpError("has ") 30661352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 30761352a58SStella Laurenzo << function.getName() << ") returns " << results.size(); 30861352a58SStella Laurenzo 30961352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 31061352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 31161352a58SStella Laurenzo return emitError() << "type of return operand " << i << " (" 31261352a58SStella Laurenzo << getOperand(i).getType() 31361352a58SStella Laurenzo << ") doesn't match function result type (" 31461352a58SStella Laurenzo << results[i] << ")" 31561352a58SStella Laurenzo << " in function @" << function.getName(); 31661352a58SStella Laurenzo 31761352a58SStella Laurenzo return success(); 31861352a58SStella Laurenzo } 319