1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // Custom asm helpers 18 //===----------------------------------------------------------------------===// 19 20 /// Parse and print an ordering clause for a variadic of consuming tokens 21 /// and an optional producing token. 22 /// 23 /// Syntax: 24 /// ordering(%0, %1 -> !ml_program.token) 25 /// ordering(() -> !ml_program.token) 26 /// ordering(%0, %1) 27 /// 28 /// If both the consuming and producing token are not present on the op, then 29 /// the clause prints nothing. 30 static ParseResult parseTokenOrdering( 31 OpAsmParser &parser, 32 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, 33 Type &produceTokenType) { 34 if (failed(parser.parseOptionalKeyword("ordering")) || 35 failed(parser.parseLParen())) 36 return success(); 37 38 // Parse consuming token list. If there are no consuming tokens, the 39 // '()' null list represents this. 40 if (succeeded(parser.parseOptionalLParen())) { 41 if (failed(parser.parseRParen())) 42 return failure(); 43 } else { 44 if (failed(parser.parseOperandList(consumeTokens, 45 /*requiredOperandCount=*/-1))) 46 return failure(); 47 } 48 49 // Parse optional producer token. 50 if (succeeded(parser.parseOptionalArrow())) 51 if (failed(parser.parseType(produceTokenType))) 52 return failure(); 53 54 if (failed(parser.parseRParen())) 55 return failure(); 56 57 return success(); 58 } 59 60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op, 61 OperandRange consumeTokens, 62 Type produceTokenType) { 63 if (consumeTokens.empty() && !produceTokenType) 64 return; 65 66 p << " ordering("; 67 if (consumeTokens.empty()) 68 p << "()"; 69 else 70 p.printOperands(consumeTokens); 71 if (produceTokenType) { 72 p << " -> "; 73 p.printType(produceTokenType); 74 } 75 p << ")"; 76 } 77 78 /// some.op custom<TypeOrAttr>($type, $attr) 79 /// 80 /// Uninitialized: 81 /// some.op : tensor<3xi32> 82 /// Initialized to narrower type than op: 83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 84 static ParseResult parseTypedInitialValue(OpAsmParser &parser, 85 TypeAttr &typeAttr, Attribute &attr) { 86 if (succeeded(parser.parseOptionalLParen())) { 87 if (failed(parser.parseAttribute(attr))) 88 return failure(); 89 if (failed(parser.parseRParen())) 90 return failure(); 91 } 92 93 Type type; 94 if (failed(parser.parseColonType(type))) 95 return failure(); 96 typeAttr = TypeAttr::get(type); 97 return success(); 98 } 99 100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 101 TypeAttr type, Attribute attr) { 102 if (attr) { 103 p << "("; 104 p.printAttribute(attr); 105 p << ")"; 106 } 107 108 p << " : "; 109 p.printAttribute(type); 110 } 111 112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 113 /// -> 114 /// some.op public @foo 115 /// some.op private @foo 116 static ParseResult parseSymbolVisibility(OpAsmParser &parser, 117 StringAttr &symVisibilityAttr) { 118 StringRef symVisibility; 119 (void)parser.parseOptionalKeyword(&symVisibility, 120 {"public", "private", "nested"}); 121 if (symVisibility.empty()) 122 return parser.emitError(parser.getCurrentLocation()) 123 << "expected 'public', 'private', or 'nested'"; 124 if (!symVisibility.empty()) 125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 126 return success(); 127 } 128 129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 130 StringAttr symVisibilityAttr) { 131 if (!symVisibilityAttr) 132 p << "public"; 133 else 134 p << symVisibilityAttr.getValue(); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // TableGen'd op method definitions 139 //===----------------------------------------------------------------------===// 140 141 #define GET_OP_CLASSES 142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 143 144 //===----------------------------------------------------------------------===// 145 // FuncOp 146 //===----------------------------------------------------------------------===// 147 148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 149 auto buildFuncType = 150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 151 function_interface_impl::VariadicFlag, 152 std::string &) { return builder.getFunctionType(argTypes, results); }; 153 154 return function_interface_impl::parseFunctionOp( 155 parser, result, /*allowVariadic=*/false, buildFuncType); 156 } 157 158 void FuncOp::print(OpAsmPrinter &p) { 159 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 160 } 161 162 //===----------------------------------------------------------------------===// 163 // GlobalOp 164 //===----------------------------------------------------------------------===// 165 166 LogicalResult GlobalOp::verify() { 167 if (!getIsMutable() && !getValue()) 168 return emitOpError() << "immutable global must have an initial value"; 169 return success(); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // GlobalLoadOp 174 //===----------------------------------------------------------------------===// 175 176 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { 177 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 178 getOperation()->getParentOp(), getGlobalAttr()); 179 } 180 181 LogicalResult 182 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 183 GlobalOp referrent = getGlobalOp(symbolTable); 184 if (!referrent) 185 return emitOpError() << "undefined global: " << getGlobal(); 186 187 if (referrent.getType() != getResult().getType()) { 188 return emitOpError() << "cannot load from global typed " 189 << referrent.getType() << " as " 190 << getResult().getType(); 191 } 192 193 return success(); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // GlobalLoadConstOp 198 //===----------------------------------------------------------------------===// 199 200 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 201 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 202 getOperation()->getParentOp(), getGlobalAttr()); 203 } 204 205 LogicalResult 206 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 207 GlobalOp referrent = getGlobalOp(symbolTable); 208 if (!referrent) 209 return emitOpError() << "undefined global: " << getGlobal(); 210 211 if (referrent.getIsMutable()) 212 return emitOpError() << "cannot load as const from mutable global " 213 << getGlobal(); 214 215 if (referrent.getType() != getResult().getType()) 216 return emitOpError() << "cannot load from global typed " 217 << referrent.getType() << " as " 218 << getResult().getType(); 219 220 return success(); 221 } 222 223 //===----------------------------------------------------------------------===// 224 // GlobalStoreOp 225 //===----------------------------------------------------------------------===// 226 227 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { 228 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 229 getOperation()->getParentOp(), getGlobalAttr()); 230 } 231 232 LogicalResult 233 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 234 GlobalOp referrent = getGlobalOp(symbolTable); 235 if (!referrent) 236 return emitOpError() << "undefined global: " << getGlobal(); 237 238 if (!referrent.getIsMutable()) { 239 return emitOpError() << "cannot store to an immutable global " 240 << getGlobal(); 241 } 242 243 if (referrent.getType() != getValue().getType()) { 244 return emitOpError() << "cannot store to a global typed " 245 << referrent.getType() << " from " 246 << getValue().getType(); 247 } 248 249 return success(); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // SubgraphOp 254 //===----------------------------------------------------------------------===// 255 256 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 257 auto buildFuncType = 258 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 259 function_interface_impl::VariadicFlag, 260 std::string &) { return builder.getFunctionType(argTypes, results); }; 261 262 return function_interface_impl::parseFunctionOp( 263 parser, result, /*allowVariadic=*/false, buildFuncType); 264 } 265 266 void SubgraphOp::print(OpAsmPrinter &p) { 267 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // OutputOp 272 //===----------------------------------------------------------------------===// 273 274 LogicalResult OutputOp::verify() { 275 auto function = cast<SubgraphOp>((*this)->getParentOp()); 276 277 // The operand number and types must match the function signature. 278 const auto &results = function.getFunctionType().getResults(); 279 if (getNumOperands() != results.size()) 280 return emitOpError("has ") 281 << getNumOperands() << " operands, but enclosing function (@" 282 << function.getName() << ") outputs " << results.size(); 283 284 for (unsigned i = 0, e = results.size(); i != e; ++i) 285 if (getOperand(i).getType() != results[i]) 286 return emitError() << "type of output operand " << i << " (" 287 << getOperand(i).getType() 288 << ") doesn't match function result type (" 289 << results[i] << ")" 290 << " in function @" << function.getName(); 291 292 return success(); 293 } 294 295 //===----------------------------------------------------------------------===// 296 // ReturnOp 297 //===----------------------------------------------------------------------===// 298 299 LogicalResult ReturnOp::verify() { 300 auto function = cast<FuncOp>((*this)->getParentOp()); 301 302 // The operand number and types must match the function signature. 303 const auto &results = function.getFunctionType().getResults(); 304 if (getNumOperands() != results.size()) 305 return emitOpError("has ") 306 << getNumOperands() << " operands, but enclosing function (@" 307 << function.getName() << ") returns " << results.size(); 308 309 for (unsigned i = 0, e = results.size(); i != e; ++i) 310 if (getOperand(i).getType() != results[i]) 311 return emitError() << "type of return operand " << i << " (" 312 << getOperand(i).getType() 313 << ") doesn't match function result type (" 314 << results[i] << ")" 315 << " in function @" << function.getName(); 316 317 return success(); 318 } 319