1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" 18 19 void AsyncDialect::initialize() { 20 addOperations< 21 #define GET_OP_LIST 22 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 23 >(); 24 addTypes< 25 #define GET_TYPEDEF_LIST 26 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 27 >(); 28 } 29 30 //===----------------------------------------------------------------------===// 31 // YieldOp 32 //===----------------------------------------------------------------------===// 33 34 static LogicalResult verify(YieldOp op) { 35 // Get the underlying value types from async values returned from the 36 // parent `async.execute` operation. 37 auto executeOp = op->getParentOfType<ExecuteOp>(); 38 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 39 return result.getType().cast<ValueType>().getValueType(); 40 }); 41 42 if (op.getOperandTypes() != types) 43 return op.emitOpError("operand types do not match the types returned from " 44 "the parent ExecuteOp"); 45 46 return success(); 47 } 48 49 //===----------------------------------------------------------------------===// 50 /// ExecuteOp 51 //===----------------------------------------------------------------------===// 52 53 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 54 55 void ExecuteOp::getNumRegionInvocations( 56 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { 57 (void)operands; 58 assert(countPerRegion.empty()); 59 countPerRegion.push_back(1); 60 } 61 62 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 63 ArrayRef<Attribute> operands, 64 SmallVectorImpl<RegionSuccessor> ®ions) { 65 // The `body` region branch back to the parent operation. 66 if (index.hasValue()) { 67 assert(*index == 0); 68 regions.push_back(RegionSuccessor(getResults())); 69 return; 70 } 71 72 // Otherwise the successor is the body region. 73 regions.push_back(RegionSuccessor(&body())); 74 } 75 76 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 77 TypeRange resultTypes, ValueRange dependencies, 78 ValueRange operands, BodyBuilderFn bodyBuilder) { 79 80 result.addOperands(dependencies); 81 result.addOperands(operands); 82 83 // Add derived `operand_segment_sizes` attribute based on parsed operands. 84 int32_t numDependencies = dependencies.size(); 85 int32_t numOperands = operands.size(); 86 auto operandSegmentSizes = DenseIntElementsAttr::get( 87 VectorType::get({2}, builder.getIntegerType(32)), 88 {numDependencies, numOperands}); 89 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 90 91 // First result is always a token, and then `resultTypes` wrapped into 92 // `async.value`. 93 result.addTypes({TokenType::get(result.getContext())}); 94 for (Type type : resultTypes) 95 result.addTypes(ValueType::get(type)); 96 97 // Add a body region with block arguments as unwrapped async value operands. 98 Region *bodyRegion = result.addRegion(); 99 bodyRegion->push_back(new Block); 100 Block &bodyBlock = bodyRegion->front(); 101 for (Value operand : operands) { 102 auto valueType = operand.getType().dyn_cast<ValueType>(); 103 bodyBlock.addArgument(valueType ? valueType.getValueType() 104 : operand.getType()); 105 } 106 107 // Create the default terminator if the builder is not provided and if the 108 // expected result is empty. Otherwise, leave this to the caller 109 // because we don't know which values to return from the execute op. 110 if (resultTypes.empty() && !bodyBuilder) { 111 OpBuilder::InsertionGuard guard(builder); 112 builder.setInsertionPointToStart(&bodyBlock); 113 builder.create<async::YieldOp>(result.location, ValueRange()); 114 } else if (bodyBuilder) { 115 OpBuilder::InsertionGuard guard(builder); 116 builder.setInsertionPointToStart(&bodyBlock); 117 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 118 } 119 } 120 121 static void print(OpAsmPrinter &p, ExecuteOp op) { 122 p << op.getOperationName(); 123 124 // [%tokens,...] 125 if (!op.dependencies().empty()) 126 p << " [" << op.dependencies() << "]"; 127 128 // (%value as %unwrapped: !async.value<!arg.type>, ...) 129 if (!op.operands().empty()) { 130 p << " ("; 131 Block *entry = op.body().empty() ? nullptr : &op.body().front(); 132 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 133 Value argument = entry ? entry->getArgument(n++) : Value(); 134 p << operand << " as " << argument << ": " << operand.getType(); 135 }); 136 p << ")"; 137 } 138 139 // -> (!async.value<!return.type>, ...) 140 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); 141 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 142 {kOperandSegmentSizesAttr}); 143 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 144 } 145 146 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 147 MLIRContext *ctx = result.getContext(); 148 149 // Sizes of parsed variadic operands, will be updated below after parsing. 150 int32_t numDependencies = 0; 151 int32_t numOperands = 0; 152 153 auto tokenTy = TokenType::get(ctx); 154 155 // Parse dependency tokens. 156 if (succeeded(parser.parseOptionalLSquare())) { 157 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 158 if (parser.parseOperandList(tokenArgs) || 159 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 160 parser.parseRSquare()) 161 return failure(); 162 163 numDependencies = tokenArgs.size(); 164 } 165 166 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 167 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 168 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 169 SmallVector<Type, 4> valueTypes; 170 SmallVector<Type, 4> unwrappedTypes; 171 172 if (succeeded(parser.parseOptionalLParen())) { 173 auto argsLoc = parser.getCurrentLocation(); 174 175 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 176 auto parseAsyncValueArg = [&]() -> ParseResult { 177 if (parser.parseOperand(valueArgs.emplace_back()) || 178 parser.parseKeyword("as") || 179 parser.parseOperand(unwrappedArgs.emplace_back()) || 180 parser.parseColonType(valueTypes.emplace_back())) 181 return failure(); 182 183 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 184 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 185 186 return success(); 187 }; 188 189 // If the next token is `)` skip async value arguments parsing. 190 if (failed(parser.parseOptionalRParen())) { 191 do { 192 if (parseAsyncValueArg()) 193 return failure(); 194 } while (succeeded(parser.parseOptionalComma())); 195 196 if (parser.parseRParen() || 197 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 198 result.operands)) 199 return failure(); 200 } 201 202 numOperands = valueArgs.size(); 203 } 204 205 // Add derived `operand_segment_sizes` attribute based on parsed operands. 206 auto operandSegmentSizes = DenseIntElementsAttr::get( 207 VectorType::get({2}, parser.getBuilder().getI32Type()), 208 {numDependencies, numOperands}); 209 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 210 211 // Parse the types of results returned from the async execute op. 212 SmallVector<Type, 4> resultTypes; 213 if (parser.parseOptionalArrowTypeList(resultTypes)) 214 return failure(); 215 216 // Async execute first result is always a completion token. 217 parser.addTypeToList(tokenTy, result.types); 218 parser.addTypesToList(resultTypes, result.types); 219 220 // Parse operation attributes. 221 NamedAttrList attrs; 222 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 223 return failure(); 224 result.addAttributes(attrs); 225 226 // Parse asynchronous region. 227 Region *body = result.addRegion(); 228 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 229 /*argTypes=*/{unwrappedTypes}, 230 /*enableNameShadowing=*/false)) 231 return failure(); 232 233 return success(); 234 } 235 236 static LogicalResult verify(ExecuteOp op) { 237 // Unwrap async.execute value operands types. 238 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 239 return operand.getType().cast<ValueType>().getValueType(); 240 }); 241 242 // Verify that unwrapped argument types matches the body region arguments. 243 if (op.body().getArgumentTypes() != unwrappedTypes) 244 return op.emitOpError("async body region argument types do not match the " 245 "execute operation arguments types"); 246 247 return success(); 248 } 249 250 //===----------------------------------------------------------------------===// 251 /// CreateGroupOp 252 //===----------------------------------------------------------------------===// 253 254 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, 255 PatternRewriter &rewriter) { 256 // Find all `await_all` users of the group. 257 llvm::SmallVector<AwaitAllOp> awaitAllUsers; 258 259 auto isAwaitAll = [&](Operation *op) -> bool { 260 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { 261 awaitAllUsers.push_back(awaitAll); 262 return true; 263 } 264 return false; 265 }; 266 267 // Check if all users of the group are `await_all` operations. 268 if (!llvm::all_of(op->getUsers(), isAwaitAll)) 269 return failure(); 270 271 // If group is only awaited without adding anything to it, we can safely erase 272 // the create operation and all users. 273 for (AwaitAllOp awaitAll : awaitAllUsers) 274 rewriter.eraseOp(awaitAll); 275 rewriter.eraseOp(op); 276 277 return success(); 278 } 279 280 //===----------------------------------------------------------------------===// 281 /// AwaitOp 282 //===----------------------------------------------------------------------===// 283 284 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 285 ArrayRef<NamedAttribute> attrs) { 286 result.addOperands({operand}); 287 result.attributes.append(attrs.begin(), attrs.end()); 288 289 // Add unwrapped async.value type to the returned values types. 290 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 291 result.addTypes(valueType.getValueType()); 292 } 293 294 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 295 Type &resultType) { 296 if (parser.parseType(operandType)) 297 return failure(); 298 299 // Add unwrapped async.value type to the returned values types. 300 if (auto valueType = operandType.dyn_cast<ValueType>()) 301 resultType = valueType.getValueType(); 302 303 return success(); 304 } 305 306 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 307 Type operandType, Type resultType) { 308 p << operandType; 309 } 310 311 static LogicalResult verify(AwaitOp op) { 312 Type argType = op.operand().getType(); 313 314 // Awaiting on a token does not have any results. 315 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 316 return op.emitOpError("awaiting on a token must have empty result"); 317 318 // Awaiting on a value unwraps the async value type. 319 if (auto value = argType.dyn_cast<ValueType>()) { 320 if (*op.getResultType() != value.getValueType()) 321 return op.emitOpError() 322 << "result type " << *op.getResultType() 323 << " does not match async value type " << value.getValueType(); 324 } 325 326 return success(); 327 } 328 329 //===----------------------------------------------------------------------===// 330 // TableGen'd op method definitions 331 //===----------------------------------------------------------------------===// 332 333 #define GET_OP_CLASSES 334 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 335 336 //===----------------------------------------------------------------------===// 337 // TableGen'd type method definitions 338 //===----------------------------------------------------------------------===// 339 340 #define GET_TYPEDEF_CLASSES 341 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 342 343 void ValueType::print(DialectAsmPrinter &printer) const { 344 printer << getMnemonic(); 345 printer << "<"; 346 printer.printType(getValueType()); 347 printer << '>'; 348 } 349 350 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { 351 Type ty; 352 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 353 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 354 return Type(); 355 } 356 return ValueType::get(ty); 357 } 358 359 /// Print a type registered to this dialect. 360 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 361 if (failed(generatedTypePrinter(type, os))) 362 llvm_unreachable("unexpected 'async' type kind"); 363 } 364 365 /// Parse a type registered to this dialect. 366 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 367 StringRef typeTag; 368 if (parser.parseKeyword(&typeTag)) 369 return Type(); 370 Type genType; 371 auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), 372 parser, typeTag, genType); 373 if (parseResult.hasValue()) 374 return genType; 375 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; 376 return {}; 377 } 378