1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MLProgram/IR/MLProgram.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/FunctionImplementation.h" 12 13 using namespace mlir; 14 using namespace mlir::ml_program; 15 16 //===----------------------------------------------------------------------===// 17 // Custom asm helpers 18 //===----------------------------------------------------------------------===// 19 20 /// Parse and print an ordering clause for a variadic of consuming tokens 21 /// and an producing token. 22 /// 23 /// Syntax: 24 /// ordering(%0, %1 -> !ml_program.token) 25 /// ordering(() -> !ml_program.token) 26 /// 27 /// If both the consuming and producing token are not present on the op, then 28 /// the clause prints nothing. 29 static ParseResult parseTokenOrdering( 30 OpAsmParser &parser, 31 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens, 32 Type &produceTokenType) { 33 if (failed(parser.parseOptionalKeyword("ordering")) || 34 failed(parser.parseLParen())) 35 return success(); 36 37 // Parse consuming token list. If there are no consuming tokens, the 38 // '()' null list represents this. 39 if (succeeded(parser.parseOptionalLParen())) { 40 if (failed(parser.parseRParen())) 41 return failure(); 42 } else { 43 if (failed(parser.parseOperandList(consumeTokens, 44 /*requiredOperandCount=*/-1))) 45 return failure(); 46 } 47 48 // Parse producer token. 49 if (failed(parser.parseArrow())) 50 return failure(); 51 if (failed(parser.parseType(produceTokenType))) 52 return failure(); 53 54 if (failed(parser.parseRParen())) 55 return failure(); 56 57 return success(); 58 } 59 60 static void printTokenOrdering(OpAsmPrinter &p, Operation *op, 61 OperandRange consumeTokens, 62 Type produceTokenType) { 63 if (consumeTokens.empty() && !produceTokenType) 64 return; 65 66 p << " ordering("; 67 if (consumeTokens.empty()) 68 p << "()"; 69 else 70 p.printOperands(consumeTokens); 71 if (produceTokenType) { 72 p << " -> "; 73 p.printType(produceTokenType); 74 } 75 p << ")"; 76 } 77 78 /// some.op custom<TypeOrAttr>($type, $attr) 79 /// 80 /// Uninitialized: 81 /// some.op : tensor<3xi32> 82 /// Initialized to narrower type than op: 83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32> 84 static ParseResult parseTypedInitialValue(OpAsmParser &parser, 85 TypeAttr &typeAttr, Attribute &attr) { 86 if (succeeded(parser.parseOptionalLParen())) { 87 if (failed(parser.parseAttribute(attr))) 88 return failure(); 89 if (failed(parser.parseRParen())) 90 return failure(); 91 } 92 93 Type type; 94 if (failed(parser.parseColonType(type))) 95 return failure(); 96 typeAttr = TypeAttr::get(type); 97 return success(); 98 } 99 100 static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, 101 TypeAttr type, Attribute attr) { 102 if (attr) { 103 p << "("; 104 p.printAttribute(attr); 105 p << ")"; 106 } 107 108 p << " : "; 109 p.printAttribute(type); 110 } 111 112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name 113 /// -> 114 /// some.op public @foo 115 /// some.op private @foo 116 static ParseResult parseSymbolVisibility(OpAsmParser &parser, 117 StringAttr &symVisibilityAttr) { 118 StringRef symVisibility; 119 (void)parser.parseOptionalKeyword(&symVisibility, 120 {"public", "private", "nested"}); 121 if (symVisibility.empty()) 122 return parser.emitError(parser.getCurrentLocation()) 123 << "expected 'public', 'private', or 'nested'"; 124 if (!symVisibility.empty()) 125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); 126 return success(); 127 } 128 129 static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, 130 StringAttr symVisibilityAttr) { 131 if (!symVisibilityAttr) 132 p << "public"; 133 else 134 p << symVisibilityAttr.getValue(); 135 } 136 137 //===----------------------------------------------------------------------===// 138 // TableGen'd op method definitions 139 //===----------------------------------------------------------------------===// 140 141 #define GET_OP_CLASSES 142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" 143 144 //===----------------------------------------------------------------------===// 145 // FuncOp 146 //===----------------------------------------------------------------------===// 147 148 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 149 auto buildFuncType = 150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 151 function_interface_impl::VariadicFlag, 152 std::string &) { return builder.getFunctionType(argTypes, results); }; 153 154 return function_interface_impl::parseFunctionOp( 155 parser, result, /*allowVariadic=*/false, buildFuncType); 156 } 157 158 void FuncOp::print(OpAsmPrinter &p) { 159 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 160 } 161 162 //===----------------------------------------------------------------------===// 163 // GlobalOp 164 //===----------------------------------------------------------------------===// 165 166 LogicalResult GlobalOp::verify() { 167 if (!getIsMutable() && !getValue()) 168 return emitOpError() << "immutable global must have an initial value"; 169 return success(); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // GlobalLoadOp 174 //===----------------------------------------------------------------------===// 175 176 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) { 177 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 178 getOperation()->getParentOp(), getGlobalAttr()); 179 } 180 181 LogicalResult 182 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 183 GlobalOp referrent = getGlobalOp(symbolTable); 184 if (!referrent) 185 return emitOpError() << "undefined global: " << getGlobal(); 186 187 if (referrent.getType() != getResult().getType()) { 188 return emitOpError() << "cannot load from global typed " 189 << referrent.getType() << " as " 190 << getResult().getType(); 191 } 192 193 return success(); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // GlobalLoadConstOp 198 //===----------------------------------------------------------------------===// 199 200 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { 201 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 202 getOperation()->getParentOp(), getGlobalAttr()); 203 } 204 205 LogicalResult 206 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 207 GlobalOp referrent = getGlobalOp(symbolTable); 208 if (!referrent) 209 return emitOpError() << "undefined global: " << getGlobal(); 210 211 if (referrent.getIsMutable()) 212 return emitOpError() << "cannot load as const from mutable global " 213 << getGlobal(); 214 215 if (referrent.getType() != getResult().getType()) 216 return emitOpError() << "cannot load from global typed " 217 << referrent.getType() << " as " 218 << getResult().getType(); 219 220 return success(); 221 } 222 223 //===----------------------------------------------------------------------===// 224 // GlobalLoadGraphOp 225 //===----------------------------------------------------------------------===// 226 227 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 228 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 229 getOperation()->getParentOp(), getGlobalAttr()); 230 } 231 232 LogicalResult 233 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 234 GlobalOp referrent = getGlobalOp(symbolTable); 235 if (!referrent) 236 return emitOpError() << "undefined global: " << getGlobal(); 237 238 if (referrent.getType() != getResult().getType()) { 239 return emitOpError() << "cannot load from global typed " 240 << referrent.getType() << " as " 241 << getResult().getType(); 242 } 243 244 return success(); 245 } 246 247 //===----------------------------------------------------------------------===// 248 // GlobalStoreOp 249 //===----------------------------------------------------------------------===// 250 251 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) { 252 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 253 getOperation()->getParentOp(), getGlobalAttr()); 254 } 255 256 LogicalResult 257 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 258 GlobalOp referrent = getGlobalOp(symbolTable); 259 if (!referrent) 260 return emitOpError() << "undefined global: " << getGlobal(); 261 262 if (!referrent.getIsMutable()) { 263 return emitOpError() << "cannot store to an immutable global " 264 << getGlobal(); 265 } 266 267 if (referrent.getType() != getValue().getType()) { 268 return emitOpError() << "cannot store to a global typed " 269 << referrent.getType() << " from " 270 << getValue().getType(); 271 } 272 273 return success(); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // GlobalStoreGraphOp 278 //===----------------------------------------------------------------------===// 279 280 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { 281 return symbolTable.lookupNearestSymbolFrom<GlobalOp>( 282 getOperation()->getParentOp(), getGlobalAttr()); 283 } 284 285 LogicalResult 286 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 287 GlobalOp referrent = getGlobalOp(symbolTable); 288 if (!referrent) 289 return emitOpError() << "undefined global: " << getGlobal(); 290 291 if (!referrent.getIsMutable()) { 292 return emitOpError() << "cannot store to an immutable global " 293 << getGlobal(); 294 } 295 296 if (referrent.getType() != getValue().getType()) { 297 return emitOpError() << "cannot store to a global typed " 298 << referrent.getType() << " from " 299 << getValue().getType(); 300 } 301 302 return success(); 303 } 304 305 //===----------------------------------------------------------------------===// 306 // SubgraphOp 307 //===----------------------------------------------------------------------===// 308 309 ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) { 310 auto buildFuncType = 311 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 312 function_interface_impl::VariadicFlag, 313 std::string &) { return builder.getFunctionType(argTypes, results); }; 314 315 return function_interface_impl::parseFunctionOp( 316 parser, result, /*allowVariadic=*/false, buildFuncType); 317 } 318 319 void SubgraphOp::print(OpAsmPrinter &p) { 320 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 321 } 322 323 //===----------------------------------------------------------------------===// 324 // OutputOp 325 //===----------------------------------------------------------------------===// 326 327 LogicalResult OutputOp::verify() { 328 auto function = cast<SubgraphOp>((*this)->getParentOp()); 329 330 // The operand number and types must match the function signature. 331 const auto &results = function.getFunctionType().getResults(); 332 if (getNumOperands() != results.size()) 333 return emitOpError("has ") 334 << getNumOperands() << " operands, but enclosing function (@" 335 << function.getName() << ") outputs " << results.size(); 336 337 for (unsigned i = 0, e = results.size(); i != e; ++i) 338 if (getOperand(i).getType() != results[i]) 339 return emitError() << "type of output operand " << i << " (" 340 << getOperand(i).getType() 341 << ") doesn't match function result type (" 342 << results[i] << ")" 343 << " in function @" << function.getName(); 344 345 return success(); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // ReturnOp 350 //===----------------------------------------------------------------------===// 351 352 LogicalResult ReturnOp::verify() { 353 auto function = cast<FuncOp>((*this)->getParentOp()); 354 355 // The operand number and types must match the function signature. 356 const auto &results = function.getFunctionType().getResults(); 357 if (getNumOperands() != results.size()) 358 return emitOpError("has ") 359 << getNumOperands() << " operands, but enclosing function (@" 360 << function.getName() << ") returns " << results.size(); 361 362 for (unsigned i = 0, e = results.size(); i != e; ++i) 363 if (getOperand(i).getType() != results[i]) 364 return emitError() << "type of return operand " << i << " (" 365 << getOperand(i).getType() 366 << ") doesn't match function result type (" 367 << results[i] << ")" 368 << " in function @" << function.getName(); 369 370 return success(); 371 } 372