1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" 18 19 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; 20 21 void AsyncDialect::initialize() { 22 addOperations< 23 #define GET_OP_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 25 >(); 26 addTypes< 27 #define GET_TYPEDEF_LIST 28 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 29 >(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // YieldOp 34 //===----------------------------------------------------------------------===// 35 36 static LogicalResult verify(YieldOp op) { 37 // Get the underlying value types from async values returned from the 38 // parent `async.execute` operation. 39 auto executeOp = op->getParentOfType<ExecuteOp>(); 40 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 41 return result.getType().cast<ValueType>().getValueType(); 42 }); 43 44 if (op.getOperandTypes() != types) 45 return op.emitOpError("operand types do not match the types returned from " 46 "the parent ExecuteOp"); 47 48 return success(); 49 } 50 51 //===----------------------------------------------------------------------===// 52 /// ExecuteOp 53 //===----------------------------------------------------------------------===// 54 55 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 56 57 void ExecuteOp::getNumRegionInvocations( 58 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { 59 (void)operands; 60 assert(countPerRegion.empty()); 61 countPerRegion.push_back(1); 62 } 63 64 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 65 ArrayRef<Attribute> operands, 66 SmallVectorImpl<RegionSuccessor> ®ions) { 67 // The `body` region branch back to the parent operation. 68 if (index.hasValue()) { 69 assert(*index == 0); 70 regions.push_back(RegionSuccessor(getResults())); 71 return; 72 } 73 74 // Otherwise the successor is the body region. 75 regions.push_back(RegionSuccessor(&body())); 76 } 77 78 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 79 TypeRange resultTypes, ValueRange dependencies, 80 ValueRange operands, BodyBuilderFn bodyBuilder) { 81 82 result.addOperands(dependencies); 83 result.addOperands(operands); 84 85 // Add derived `operand_segment_sizes` attribute based on parsed operands. 86 int32_t numDependencies = dependencies.size(); 87 int32_t numOperands = operands.size(); 88 auto operandSegmentSizes = DenseIntElementsAttr::get( 89 VectorType::get({2}, builder.getIntegerType(32)), 90 {numDependencies, numOperands}); 91 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 92 93 // First result is always a token, and then `resultTypes` wrapped into 94 // `async.value`. 95 result.addTypes({TokenType::get(result.getContext())}); 96 for (Type type : resultTypes) 97 result.addTypes(ValueType::get(type)); 98 99 // Add a body region with block arguments as unwrapped async value operands. 100 Region *bodyRegion = result.addRegion(); 101 bodyRegion->push_back(new Block); 102 Block &bodyBlock = bodyRegion->front(); 103 for (Value operand : operands) { 104 auto valueType = operand.getType().dyn_cast<ValueType>(); 105 bodyBlock.addArgument(valueType ? valueType.getValueType() 106 : operand.getType()); 107 } 108 109 // Create the default terminator if the builder is not provided and if the 110 // expected result is empty. Otherwise, leave this to the caller 111 // because we don't know which values to return from the execute op. 112 if (resultTypes.empty() && !bodyBuilder) { 113 OpBuilder::InsertionGuard guard(builder); 114 builder.setInsertionPointToStart(&bodyBlock); 115 builder.create<async::YieldOp>(result.location, ValueRange()); 116 } else if (bodyBuilder) { 117 OpBuilder::InsertionGuard guard(builder); 118 builder.setInsertionPointToStart(&bodyBlock); 119 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 120 } 121 } 122 123 static void print(OpAsmPrinter &p, ExecuteOp op) { 124 p << op.getOperationName(); 125 126 // [%tokens,...] 127 if (!op.dependencies().empty()) 128 p << " [" << op.dependencies() << "]"; 129 130 // (%value as %unwrapped: !async.value<!arg.type>, ...) 131 if (!op.operands().empty()) { 132 p << " ("; 133 Block *entry = op.body().empty() ? nullptr : &op.body().front(); 134 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 135 Value argument = entry ? entry->getArgument(n++) : Value(); 136 p << operand << " as " << argument << ": " << operand.getType(); 137 }); 138 p << ")"; 139 } 140 141 // -> (!async.value<!return.type>, ...) 142 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); 143 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 144 {kOperandSegmentSizesAttr}); 145 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 146 } 147 148 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 149 MLIRContext *ctx = result.getContext(); 150 151 // Sizes of parsed variadic operands, will be updated below after parsing. 152 int32_t numDependencies = 0; 153 int32_t numOperands = 0; 154 155 auto tokenTy = TokenType::get(ctx); 156 157 // Parse dependency tokens. 158 if (succeeded(parser.parseOptionalLSquare())) { 159 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 160 if (parser.parseOperandList(tokenArgs) || 161 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 162 parser.parseRSquare()) 163 return failure(); 164 165 numDependencies = tokenArgs.size(); 166 } 167 168 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 169 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 170 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 171 SmallVector<Type, 4> valueTypes; 172 SmallVector<Type, 4> unwrappedTypes; 173 174 if (succeeded(parser.parseOptionalLParen())) { 175 auto argsLoc = parser.getCurrentLocation(); 176 177 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 178 auto parseAsyncValueArg = [&]() -> ParseResult { 179 if (parser.parseOperand(valueArgs.emplace_back()) || 180 parser.parseKeyword("as") || 181 parser.parseOperand(unwrappedArgs.emplace_back()) || 182 parser.parseColonType(valueTypes.emplace_back())) 183 return failure(); 184 185 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 186 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 187 188 return success(); 189 }; 190 191 // If the next token is `)` skip async value arguments parsing. 192 if (failed(parser.parseOptionalRParen())) { 193 do { 194 if (parseAsyncValueArg()) 195 return failure(); 196 } while (succeeded(parser.parseOptionalComma())); 197 198 if (parser.parseRParen() || 199 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 200 result.operands)) 201 return failure(); 202 } 203 204 numOperands = valueArgs.size(); 205 } 206 207 // Add derived `operand_segment_sizes` attribute based on parsed operands. 208 auto operandSegmentSizes = DenseIntElementsAttr::get( 209 VectorType::get({2}, parser.getBuilder().getI32Type()), 210 {numDependencies, numOperands}); 211 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 212 213 // Parse the types of results returned from the async execute op. 214 SmallVector<Type, 4> resultTypes; 215 if (parser.parseOptionalArrowTypeList(resultTypes)) 216 return failure(); 217 218 // Async execute first result is always a completion token. 219 parser.addTypeToList(tokenTy, result.types); 220 parser.addTypesToList(resultTypes, result.types); 221 222 // Parse operation attributes. 223 NamedAttrList attrs; 224 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 225 return failure(); 226 result.addAttributes(attrs); 227 228 // Parse asynchronous region. 229 Region *body = result.addRegion(); 230 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 231 /*argTypes=*/{unwrappedTypes}, 232 /*enableNameShadowing=*/false)) 233 return failure(); 234 235 return success(); 236 } 237 238 static LogicalResult verify(ExecuteOp op) { 239 // Unwrap async.execute value operands types. 240 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 241 return operand.getType().cast<ValueType>().getValueType(); 242 }); 243 244 // Verify that unwrapped argument types matches the body region arguments. 245 if (op.body().getArgumentTypes() != unwrappedTypes) 246 return op.emitOpError("async body region argument types do not match the " 247 "execute operation arguments types"); 248 249 return success(); 250 } 251 252 //===----------------------------------------------------------------------===// 253 /// CreateGroupOp 254 //===----------------------------------------------------------------------===// 255 256 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, 257 PatternRewriter &rewriter) { 258 // Find all `await_all` users of the group. 259 llvm::SmallVector<AwaitAllOp> awaitAllUsers; 260 261 auto isAwaitAll = [&](Operation *op) -> bool { 262 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { 263 awaitAllUsers.push_back(awaitAll); 264 return true; 265 } 266 return false; 267 }; 268 269 // Check if all users of the group are `await_all` operations. 270 if (!llvm::all_of(op->getUsers(), isAwaitAll)) 271 return failure(); 272 273 // If group is only awaited without adding anything to it, we can safely erase 274 // the create operation and all users. 275 for (AwaitAllOp awaitAll : awaitAllUsers) 276 rewriter.eraseOp(awaitAll); 277 rewriter.eraseOp(op); 278 279 return success(); 280 } 281 282 //===----------------------------------------------------------------------===// 283 /// AwaitOp 284 //===----------------------------------------------------------------------===// 285 286 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 287 ArrayRef<NamedAttribute> attrs) { 288 result.addOperands({operand}); 289 result.attributes.append(attrs.begin(), attrs.end()); 290 291 // Add unwrapped async.value type to the returned values types. 292 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 293 result.addTypes(valueType.getValueType()); 294 } 295 296 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 297 Type &resultType) { 298 if (parser.parseType(operandType)) 299 return failure(); 300 301 // Add unwrapped async.value type to the returned values types. 302 if (auto valueType = operandType.dyn_cast<ValueType>()) 303 resultType = valueType.getValueType(); 304 305 return success(); 306 } 307 308 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 309 Type operandType, Type resultType) { 310 p << operandType; 311 } 312 313 static LogicalResult verify(AwaitOp op) { 314 Type argType = op.operand().getType(); 315 316 // Awaiting on a token does not have any results. 317 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 318 return op.emitOpError("awaiting on a token must have empty result"); 319 320 // Awaiting on a value unwraps the async value type. 321 if (auto value = argType.dyn_cast<ValueType>()) { 322 if (*op.getResultType() != value.getValueType()) 323 return op.emitOpError() 324 << "result type " << *op.getResultType() 325 << " does not match async value type " << value.getValueType(); 326 } 327 328 return success(); 329 } 330 331 //===----------------------------------------------------------------------===// 332 // TableGen'd op method definitions 333 //===----------------------------------------------------------------------===// 334 335 #define GET_OP_CLASSES 336 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 337 338 //===----------------------------------------------------------------------===// 339 // TableGen'd type method definitions 340 //===----------------------------------------------------------------------===// 341 342 #define GET_TYPEDEF_CLASSES 343 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 344 345 void ValueType::print(DialectAsmPrinter &printer) const { 346 printer << getMnemonic(); 347 printer << "<"; 348 printer.printType(getValueType()); 349 printer << '>'; 350 } 351 352 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { 353 Type ty; 354 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 355 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 356 return Type(); 357 } 358 return ValueType::get(ty); 359 } 360 361 /// Print a type registered to this dialect. 362 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 363 if (failed(generatedTypePrinter(type, os))) 364 llvm_unreachable("unexpected 'async' type kind"); 365 } 366 367 /// Parse a type registered to this dialect. 368 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 369 StringRef typeTag; 370 if (parser.parseKeyword(&typeTag)) 371 return Type(); 372 Type genType; 373 auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), 374 parser, typeTag, genType); 375 if (parseResult.hasValue()) 376 return genType; 377 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; 378 return {}; 379 } 380