1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" 18 19 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; 20 21 void AsyncDialect::initialize() { 22 addOperations< 23 #define GET_OP_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 25 >(); 26 addTypes< 27 #define GET_TYPEDEF_LIST 28 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 29 >(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // YieldOp 34 //===----------------------------------------------------------------------===// 35 36 static LogicalResult verify(YieldOp op) { 37 // Get the underlying value types from async values returned from the 38 // parent `async.execute` operation. 39 auto executeOp = op->getParentOfType<ExecuteOp>(); 40 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 41 return result.getType().cast<ValueType>().getValueType(); 42 }); 43 44 if (op.getOperandTypes() != types) 45 return op.emitOpError("operand types do not match the types returned from " 46 "the parent ExecuteOp"); 47 48 return success(); 49 } 50 51 MutableOperandRange 52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) { 53 assert(!index.hasValue()); 54 return operandsMutable(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 /// ExecuteOp 59 //===----------------------------------------------------------------------===// 60 61 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 62 63 void ExecuteOp::getNumRegionInvocations( 64 ArrayRef<Attribute>, SmallVectorImpl<int64_t> &countPerRegion) { 65 assert(countPerRegion.empty()); 66 countPerRegion.push_back(1); 67 } 68 69 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { 70 assert(index == 0 && "invalid region index"); 71 return operands(); 72 } 73 74 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 75 ArrayRef<Attribute>, 76 SmallVectorImpl<RegionSuccessor> ®ions) { 77 // The `body` region branch back to the parent operation. 78 if (index.hasValue()) { 79 assert(*index == 0 && "invalid region index"); 80 regions.push_back(RegionSuccessor(results())); 81 return; 82 } 83 84 // Otherwise the successor is the body region. 85 regions.push_back(RegionSuccessor(&body(), body().getArguments())); 86 } 87 88 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 89 TypeRange resultTypes, ValueRange dependencies, 90 ValueRange operands, BodyBuilderFn bodyBuilder) { 91 92 result.addOperands(dependencies); 93 result.addOperands(operands); 94 95 // Add derived `operand_segment_sizes` attribute based on parsed operands. 96 int32_t numDependencies = dependencies.size(); 97 int32_t numOperands = operands.size(); 98 auto operandSegmentSizes = DenseIntElementsAttr::get( 99 VectorType::get({2}, builder.getIntegerType(32)), 100 {numDependencies, numOperands}); 101 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 102 103 // First result is always a token, and then `resultTypes` wrapped into 104 // `async.value`. 105 result.addTypes({TokenType::get(result.getContext())}); 106 for (Type type : resultTypes) 107 result.addTypes(ValueType::get(type)); 108 109 // Add a body region with block arguments as unwrapped async value operands. 110 Region *bodyRegion = result.addRegion(); 111 bodyRegion->push_back(new Block); 112 Block &bodyBlock = bodyRegion->front(); 113 for (Value operand : operands) { 114 auto valueType = operand.getType().dyn_cast<ValueType>(); 115 bodyBlock.addArgument(valueType ? valueType.getValueType() 116 : operand.getType()); 117 } 118 119 // Create the default terminator if the builder is not provided and if the 120 // expected result is empty. Otherwise, leave this to the caller 121 // because we don't know which values to return from the execute op. 122 if (resultTypes.empty() && !bodyBuilder) { 123 OpBuilder::InsertionGuard guard(builder); 124 builder.setInsertionPointToStart(&bodyBlock); 125 builder.create<async::YieldOp>(result.location, ValueRange()); 126 } else if (bodyBuilder) { 127 OpBuilder::InsertionGuard guard(builder); 128 builder.setInsertionPointToStart(&bodyBlock); 129 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 130 } 131 } 132 133 static void print(OpAsmPrinter &p, ExecuteOp op) { 134 p << op.getOperationName(); 135 136 // [%tokens,...] 137 if (!op.dependencies().empty()) 138 p << " [" << op.dependencies() << "]"; 139 140 // (%value as %unwrapped: !async.value<!arg.type>, ...) 141 if (!op.operands().empty()) { 142 p << " ("; 143 Block *entry = op.body().empty() ? nullptr : &op.body().front(); 144 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 145 Value argument = entry ? entry->getArgument(n++) : Value(); 146 p << operand << " as " << argument << ": " << operand.getType(); 147 }); 148 p << ")"; 149 } 150 151 // -> (!async.value<!return.type>, ...) 152 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); 153 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 154 {kOperandSegmentSizesAttr}); 155 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 156 } 157 158 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 159 MLIRContext *ctx = result.getContext(); 160 161 // Sizes of parsed variadic operands, will be updated below after parsing. 162 int32_t numDependencies = 0; 163 int32_t numOperands = 0; 164 165 auto tokenTy = TokenType::get(ctx); 166 167 // Parse dependency tokens. 168 if (succeeded(parser.parseOptionalLSquare())) { 169 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 170 if (parser.parseOperandList(tokenArgs) || 171 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 172 parser.parseRSquare()) 173 return failure(); 174 175 numDependencies = tokenArgs.size(); 176 } 177 178 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 179 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 180 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 181 SmallVector<Type, 4> valueTypes; 182 SmallVector<Type, 4> unwrappedTypes; 183 184 if (succeeded(parser.parseOptionalLParen())) { 185 auto argsLoc = parser.getCurrentLocation(); 186 187 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 188 auto parseAsyncValueArg = [&]() -> ParseResult { 189 if (parser.parseOperand(valueArgs.emplace_back()) || 190 parser.parseKeyword("as") || 191 parser.parseOperand(unwrappedArgs.emplace_back()) || 192 parser.parseColonType(valueTypes.emplace_back())) 193 return failure(); 194 195 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 196 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 197 198 return success(); 199 }; 200 201 // If the next token is `)` skip async value arguments parsing. 202 if (failed(parser.parseOptionalRParen())) { 203 do { 204 if (parseAsyncValueArg()) 205 return failure(); 206 } while (succeeded(parser.parseOptionalComma())); 207 208 if (parser.parseRParen() || 209 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 210 result.operands)) 211 return failure(); 212 } 213 214 numOperands = valueArgs.size(); 215 } 216 217 // Add derived `operand_segment_sizes` attribute based on parsed operands. 218 auto operandSegmentSizes = DenseIntElementsAttr::get( 219 VectorType::get({2}, parser.getBuilder().getI32Type()), 220 {numDependencies, numOperands}); 221 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 222 223 // Parse the types of results returned from the async execute op. 224 SmallVector<Type, 4> resultTypes; 225 if (parser.parseOptionalArrowTypeList(resultTypes)) 226 return failure(); 227 228 // Async execute first result is always a completion token. 229 parser.addTypeToList(tokenTy, result.types); 230 parser.addTypesToList(resultTypes, result.types); 231 232 // Parse operation attributes. 233 NamedAttrList attrs; 234 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 235 return failure(); 236 result.addAttributes(attrs); 237 238 // Parse asynchronous region. 239 Region *body = result.addRegion(); 240 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 241 /*argTypes=*/{unwrappedTypes}, 242 /*enableNameShadowing=*/false)) 243 return failure(); 244 245 return success(); 246 } 247 248 static LogicalResult verify(ExecuteOp op) { 249 // Unwrap async.execute value operands types. 250 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 251 return operand.getType().cast<ValueType>().getValueType(); 252 }); 253 254 // Verify that unwrapped argument types matches the body region arguments. 255 if (op.body().getArgumentTypes() != unwrappedTypes) 256 return op.emitOpError("async body region argument types do not match the " 257 "execute operation arguments types"); 258 259 return success(); 260 } 261 262 //===----------------------------------------------------------------------===// 263 /// CreateGroupOp 264 //===----------------------------------------------------------------------===// 265 266 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, 267 PatternRewriter &rewriter) { 268 // Find all `await_all` users of the group. 269 llvm::SmallVector<AwaitAllOp> awaitAllUsers; 270 271 auto isAwaitAll = [&](Operation *op) -> bool { 272 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { 273 awaitAllUsers.push_back(awaitAll); 274 return true; 275 } 276 return false; 277 }; 278 279 // Check if all users of the group are `await_all` operations. 280 if (!llvm::all_of(op->getUsers(), isAwaitAll)) 281 return failure(); 282 283 // If group is only awaited without adding anything to it, we can safely erase 284 // the create operation and all users. 285 for (AwaitAllOp awaitAll : awaitAllUsers) 286 rewriter.eraseOp(awaitAll); 287 rewriter.eraseOp(op); 288 289 return success(); 290 } 291 292 //===----------------------------------------------------------------------===// 293 /// AwaitOp 294 //===----------------------------------------------------------------------===// 295 296 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 297 ArrayRef<NamedAttribute> attrs) { 298 result.addOperands({operand}); 299 result.attributes.append(attrs.begin(), attrs.end()); 300 301 // Add unwrapped async.value type to the returned values types. 302 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 303 result.addTypes(valueType.getValueType()); 304 } 305 306 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 307 Type &resultType) { 308 if (parser.parseType(operandType)) 309 return failure(); 310 311 // Add unwrapped async.value type to the returned values types. 312 if (auto valueType = operandType.dyn_cast<ValueType>()) 313 resultType = valueType.getValueType(); 314 315 return success(); 316 } 317 318 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 319 Type operandType, Type resultType) { 320 p << operandType; 321 } 322 323 static LogicalResult verify(AwaitOp op) { 324 Type argType = op.operand().getType(); 325 326 // Awaiting on a token does not have any results. 327 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 328 return op.emitOpError("awaiting on a token must have empty result"); 329 330 // Awaiting on a value unwraps the async value type. 331 if (auto value = argType.dyn_cast<ValueType>()) { 332 if (*op.getResultType() != value.getValueType()) 333 return op.emitOpError() 334 << "result type " << *op.getResultType() 335 << " does not match async value type " << value.getValueType(); 336 } 337 338 return success(); 339 } 340 341 //===----------------------------------------------------------------------===// 342 // TableGen'd op method definitions 343 //===----------------------------------------------------------------------===// 344 345 #define GET_OP_CLASSES 346 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 347 348 //===----------------------------------------------------------------------===// 349 // TableGen'd type method definitions 350 //===----------------------------------------------------------------------===// 351 352 #define GET_TYPEDEF_CLASSES 353 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 354 355 void ValueType::print(DialectAsmPrinter &printer) const { 356 printer << getMnemonic(); 357 printer << "<"; 358 printer.printType(getValueType()); 359 printer << '>'; 360 } 361 362 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { 363 Type ty; 364 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 365 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 366 return Type(); 367 } 368 return ValueType::get(ty); 369 } 370 371 /// Print a type registered to this dialect. 372 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 373 if (failed(generatedTypePrinter(type, os))) 374 llvm_unreachable("unexpected 'async' type kind"); 375 } 376 377 /// Parse a type registered to this dialect. 378 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 379 StringRef typeTag; 380 if (parser.parseKeyword(&typeTag)) 381 return Type(); 382 Type genType; 383 auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), 384 parser, typeTag, genType); 385 if (parseResult.hasValue()) 386 return genType; 387 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; 388 return {}; 389 } 390