161352a58SStella Laurenzo //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 261352a58SStella Laurenzo // 361352a58SStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 461352a58SStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 561352a58SStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 661352a58SStella Laurenzo // 761352a58SStella Laurenzo //===----------------------------------------------------------------------===// 861352a58SStella Laurenzo 961352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 1061352a58SStella Laurenzo #include "mlir/IR/Builders.h" 1161352a58SStella Laurenzo #include "mlir/IR/FunctionImplementation.h" 1261352a58SStella Laurenzo 1361352a58SStella Laurenzo using namespace mlir; 1461352a58SStella Laurenzo using namespace mlir::ml_program; 1561352a58SStella Laurenzo 1661352a58SStella Laurenzo //===----------------------------------------------------------------------===// 17*2bb25285SStella Laurenzo // Custom asm helpers 18*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 19*2bb25285SStella Laurenzo 20*2bb25285SStella Laurenzo /// some.op custom<TypeOrAttr>($type, $attr) 21*2bb25285SStella Laurenzo /// 22*2bb25285SStella Laurenzo /// Uninitialized: 23*2bb25285SStella Laurenzo /// some.op : tensor<3xi32> 24*2bb25285SStella Laurenzo /// Initialized to narrower type than op: 25*2bb25285SStella Laurenzo /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 26*2bb25285SStella Laurenzo static ParseResult parseTypedInitialValue(OpAsmParser &parser, 27*2bb25285SStella Laurenzo TypeAttr &typeAttr, Attribute &attr) { 28*2bb25285SStella Laurenzo if (succeeded(parser.parseOptionalLParen())) { 29*2bb25285SStella Laurenzo if (failed(parser.parseAttribute(attr))) 30*2bb25285SStella Laurenzo return failure(); 31*2bb25285SStella Laurenzo if (failed(parser.parseRParen())) 32*2bb25285SStella Laurenzo return failure(); 33*2bb25285SStella Laurenzo } 34*2bb25285SStella Laurenzo 35*2bb25285SStella Laurenzo Type type; 36*2bb25285SStella Laurenzo if (failed(parser.parseColonType(type))) 37*2bb25285SStella Laurenzo return failure(); 38*2bb25285SStella Laurenzo typeAttr = TypeAttr::get(type); 39*2bb25285SStella Laurenzo return success(); 40*2bb25285SStella Laurenzo } 41*2bb25285SStella Laurenzo 42*2bb25285SStella Laurenzo static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 43*2bb25285SStella Laurenzo TypeAttr type, Attribute attr) { 44*2bb25285SStella Laurenzo if (attr) { 45*2bb25285SStella Laurenzo p << "("; 46*2bb25285SStella Laurenzo p.printAttribute(attr); 47*2bb25285SStella Laurenzo p << ")"; 48*2bb25285SStella Laurenzo } 49*2bb25285SStella Laurenzo 50*2bb25285SStella Laurenzo p << " : "; 51*2bb25285SStella Laurenzo p.printAttribute(type); 52*2bb25285SStella Laurenzo } 53*2bb25285SStella Laurenzo 54*2bb25285SStella Laurenzo /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 55*2bb25285SStella Laurenzo /// -> 56*2bb25285SStella Laurenzo /// some.op public @foo 57*2bb25285SStella Laurenzo /// some.op private @foo 58*2bb25285SStella Laurenzo static ParseResult parseSymbolVisibility(OpAsmParser &parser, 59*2bb25285SStella Laurenzo StringAttr &symVisibilityAttr) { 60*2bb25285SStella Laurenzo StringRef symVisibility; 61*2bb25285SStella Laurenzo (void)parser.parseOptionalKeyword(&symVisibility, 62*2bb25285SStella Laurenzo {"public", "private", "nested"}); 63*2bb25285SStella Laurenzo if (symVisibility.empty()) 64*2bb25285SStella Laurenzo return parser.emitError(parser.getCurrentLocation()) 65*2bb25285SStella Laurenzo << "expected 'public', 'private', or 'nested'"; 66*2bb25285SStella Laurenzo if (!symVisibility.empty()) 67*2bb25285SStella Laurenzo symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 68*2bb25285SStella Laurenzo return success(); 69*2bb25285SStella Laurenzo } 70*2bb25285SStella Laurenzo 71*2bb25285SStella Laurenzo static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 72*2bb25285SStella Laurenzo StringAttr symVisibilityAttr) { 73*2bb25285SStella Laurenzo if (!symVisibilityAttr) 74*2bb25285SStella Laurenzo p << "public"; 75*2bb25285SStella Laurenzo else 76*2bb25285SStella Laurenzo p << symVisibilityAttr.getValue(); 77*2bb25285SStella Laurenzo } 78*2bb25285SStella Laurenzo 79*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 8061352a58SStella Laurenzo // TableGen'd op method definitions 8161352a58SStella Laurenzo //===----------------------------------------------------------------------===// 8261352a58SStella Laurenzo 8361352a58SStella Laurenzo #define GET_OP_CLASSES 8461352a58SStella Laurenzo #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 8561352a58SStella Laurenzo 8661352a58SStella Laurenzo //===----------------------------------------------------------------------===// 8761352a58SStella Laurenzo // FuncOp 8861352a58SStella Laurenzo //===----------------------------------------------------------------------===// 8961352a58SStella Laurenzo 9061352a58SStella Laurenzo ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 9161352a58SStella Laurenzo auto buildFuncType = 9261352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 9361352a58SStella Laurenzo function_interface_impl::VariadicFlag, 9461352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 9561352a58SStella Laurenzo 9661352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 9761352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 9861352a58SStella Laurenzo } 9961352a58SStella Laurenzo 10061352a58SStella Laurenzo void FuncOp::print(OpAsmPrinter &p) { 10161352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 10261352a58SStella Laurenzo } 10361352a58SStella Laurenzo 10461352a58SStella Laurenzo //===----------------------------------------------------------------------===// 105*2bb25285SStella Laurenzo // GlobalOp 106*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 107*2bb25285SStella Laurenzo 108*2bb25285SStella Laurenzo LogicalResult GlobalOp::verify() { 109*2bb25285SStella Laurenzo if (!getIsMutable() && !getValue()) 110*2bb25285SStella Laurenzo return emitOpError() << "immutable global must have an initial value"; 111*2bb25285SStella Laurenzo return success(); 112*2bb25285SStella Laurenzo } 113*2bb25285SStella Laurenzo 114*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 115*2bb25285SStella Laurenzo // GlobalLoadConstOp 116*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 117*2bb25285SStella Laurenzo 118*2bb25285SStella Laurenzo GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 119*2bb25285SStella Laurenzo return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 120*2bb25285SStella Laurenzo getOperation()->getParentOp(), getGlobalAttr()); 121*2bb25285SStella Laurenzo } 122*2bb25285SStella Laurenzo 123*2bb25285SStella Laurenzo LogicalResult 124*2bb25285SStella Laurenzo GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 125*2bb25285SStella Laurenzo GlobalOp referrent = getGlobalOp(symbolTable); 126*2bb25285SStella Laurenzo if (!referrent) 127*2bb25285SStella Laurenzo return emitOpError() << "undefined global: " << getGlobal(); 128*2bb25285SStella Laurenzo 129*2bb25285SStella Laurenzo if (referrent.getIsMutable()) 130*2bb25285SStella Laurenzo return emitOpError() << "cannot load as const from mutable global " 131*2bb25285SStella Laurenzo << getGlobal(); 132*2bb25285SStella Laurenzo 133*2bb25285SStella Laurenzo if (referrent.getType() != getResult().getType()) 134*2bb25285SStella Laurenzo return emitOpError() << "cannot load from global typed " 135*2bb25285SStella Laurenzo << referrent.getType() << " as " 136*2bb25285SStella Laurenzo << getResult().getType(); 137*2bb25285SStella Laurenzo 138*2bb25285SStella Laurenzo return success(); 139*2bb25285SStella Laurenzo } 140*2bb25285SStella Laurenzo 141*2bb25285SStella Laurenzo //===----------------------------------------------------------------------===// 14261352a58SStella Laurenzo // SubgraphOp 14361352a58SStella Laurenzo //===----------------------------------------------------------------------===// 14461352a58SStella Laurenzo 14561352a58SStella Laurenzo ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 14661352a58SStella Laurenzo auto buildFuncType = 14761352a58SStella Laurenzo [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 14861352a58SStella Laurenzo function_interface_impl::VariadicFlag, 14961352a58SStella Laurenzo std::string &) { return builder.getFunctionType(argTypes, results); }; 15061352a58SStella Laurenzo 15161352a58SStella Laurenzo return function_interface_impl::parseFunctionOp( 15261352a58SStella Laurenzo parser, result, /*allowVariadic=*/false, buildFuncType); 15361352a58SStella Laurenzo } 15461352a58SStella Laurenzo 15561352a58SStella Laurenzo void SubgraphOp::print(OpAsmPrinter &p) { 15661352a58SStella Laurenzo function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 15761352a58SStella Laurenzo } 15861352a58SStella Laurenzo 15961352a58SStella Laurenzo //===----------------------------------------------------------------------===// 16061352a58SStella Laurenzo // OutputOp 16161352a58SStella Laurenzo //===----------------------------------------------------------------------===// 16261352a58SStella Laurenzo 16361352a58SStella Laurenzo LogicalResult OutputOp::verify() { 16461352a58SStella Laurenzo auto function = cast<SubgraphOp>((*this)->getParentOp()); 16561352a58SStella Laurenzo 16661352a58SStella Laurenzo // The operand number and types must match the function signature. 16761352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 16861352a58SStella Laurenzo if (getNumOperands() != results.size()) 16961352a58SStella Laurenzo return emitOpError("has ") 17061352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 17161352a58SStella Laurenzo << function.getName() << ") outputs " << results.size(); 17261352a58SStella Laurenzo 17361352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 17461352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 17561352a58SStella Laurenzo return emitError() << "type of output operand " << i << " (" 17661352a58SStella Laurenzo << getOperand(i).getType() 17761352a58SStella Laurenzo << ") doesn't match function result type (" 17861352a58SStella Laurenzo << results[i] << ")" 17961352a58SStella Laurenzo << " in function @" << function.getName(); 18061352a58SStella Laurenzo 18161352a58SStella Laurenzo return success(); 18261352a58SStella Laurenzo } 18361352a58SStella Laurenzo 18461352a58SStella Laurenzo //===----------------------------------------------------------------------===// 18561352a58SStella Laurenzo // ReturnOp 18661352a58SStella Laurenzo //===----------------------------------------------------------------------===// 18761352a58SStella Laurenzo 18861352a58SStella Laurenzo LogicalResult ReturnOp::verify() { 18961352a58SStella Laurenzo auto function = cast<FuncOp>((*this)->getParentOp()); 19061352a58SStella Laurenzo 19161352a58SStella Laurenzo // The operand number and types must match the function signature. 19261352a58SStella Laurenzo const auto &results = function.getFunctionType().getResults(); 19361352a58SStella Laurenzo if (getNumOperands() != results.size()) 19461352a58SStella Laurenzo return emitOpError("has ") 19561352a58SStella Laurenzo << getNumOperands() << " operands, but enclosing function (@" 19661352a58SStella Laurenzo << function.getName() << ") returns " << results.size(); 19761352a58SStella Laurenzo 19861352a58SStella Laurenzo for (unsigned i = 0, e = results.size(); i != e; ++i) 19961352a58SStella Laurenzo if (getOperand(i).getType() != results[i]) 20061352a58SStella Laurenzo return emitError() << "type of return operand " << i << " (" 20161352a58SStella Laurenzo << getOperand(i).getType() 20261352a58SStella Laurenzo << ") doesn't match function result type (" 20361352a58SStella Laurenzo << results[i] << ")" 20461352a58SStella Laurenzo << " in function @" << function.getName(); 20561352a58SStella Laurenzo 20661352a58SStella Laurenzo return success(); 20761352a58SStella Laurenzo } 208