1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" 18 19 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; 20 21 void AsyncDialect::initialize() { 22 addOperations< 23 #define GET_OP_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 25 >(); 26 addTypes< 27 #define GET_TYPEDEF_LIST 28 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 29 >(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // YieldOp 34 //===----------------------------------------------------------------------===// 35 36 static LogicalResult verify(YieldOp op) { 37 // Get the underlying value types from async values returned from the 38 // parent `async.execute` operation. 39 auto executeOp = op->getParentOfType<ExecuteOp>(); 40 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 41 return result.getType().cast<ValueType>().getValueType(); 42 }); 43 44 if (op.getOperandTypes() != types) 45 return op.emitOpError("operand types do not match the types returned from " 46 "the parent ExecuteOp"); 47 48 return success(); 49 } 50 51 MutableOperandRange 52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) { 53 assert(!index.hasValue()); 54 return operandsMutable(); 55 } 56 57 //===----------------------------------------------------------------------===// 58 /// ExecuteOp 59 //===----------------------------------------------------------------------===// 60 61 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 62 63 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { 64 assert(index == 0 && "invalid region index"); 65 return operands(); 66 } 67 68 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 69 ArrayRef<Attribute>, 70 SmallVectorImpl<RegionSuccessor> ®ions) { 71 // The `body` region branch back to the parent operation. 72 if (index.hasValue()) { 73 assert(*index == 0 && "invalid region index"); 74 regions.push_back(RegionSuccessor(results())); 75 return; 76 } 77 78 // Otherwise the successor is the body region. 79 regions.push_back(RegionSuccessor(&body(), body().getArguments())); 80 } 81 82 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 83 TypeRange resultTypes, ValueRange dependencies, 84 ValueRange operands, BodyBuilderFn bodyBuilder) { 85 86 result.addOperands(dependencies); 87 result.addOperands(operands); 88 89 // Add derived `operand_segment_sizes` attribute based on parsed operands. 90 int32_t numDependencies = dependencies.size(); 91 int32_t numOperands = operands.size(); 92 auto operandSegmentSizes = DenseIntElementsAttr::get( 93 VectorType::get({2}, builder.getIntegerType(32)), 94 {numDependencies, numOperands}); 95 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 96 97 // First result is always a token, and then `resultTypes` wrapped into 98 // `async.value`. 99 result.addTypes({TokenType::get(result.getContext())}); 100 for (Type type : resultTypes) 101 result.addTypes(ValueType::get(type)); 102 103 // Add a body region with block arguments as unwrapped async value operands. 104 Region *bodyRegion = result.addRegion(); 105 bodyRegion->push_back(new Block); 106 Block &bodyBlock = bodyRegion->front(); 107 for (Value operand : operands) { 108 auto valueType = operand.getType().dyn_cast<ValueType>(); 109 bodyBlock.addArgument(valueType ? valueType.getValueType() 110 : operand.getType()); 111 } 112 113 // Create the default terminator if the builder is not provided and if the 114 // expected result is empty. Otherwise, leave this to the caller 115 // because we don't know which values to return from the execute op. 116 if (resultTypes.empty() && !bodyBuilder) { 117 OpBuilder::InsertionGuard guard(builder); 118 builder.setInsertionPointToStart(&bodyBlock); 119 builder.create<async::YieldOp>(result.location, ValueRange()); 120 } else if (bodyBuilder) { 121 OpBuilder::InsertionGuard guard(builder); 122 builder.setInsertionPointToStart(&bodyBlock); 123 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 124 } 125 } 126 127 static void print(OpAsmPrinter &p, ExecuteOp op) { 128 // [%tokens,...] 129 if (!op.dependencies().empty()) 130 p << " [" << op.dependencies() << "]"; 131 132 // (%value as %unwrapped: !async.value<!arg.type>, ...) 133 if (!op.operands().empty()) { 134 p << " ("; 135 Block *entry = op.body().empty() ? nullptr : &op.body().front(); 136 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 137 Value argument = entry ? entry->getArgument(n++) : Value(); 138 p << operand << " as " << argument << ": " << operand.getType(); 139 }); 140 p << ")"; 141 } 142 143 // -> (!async.value<!return.type>, ...) 144 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); 145 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 146 {kOperandSegmentSizesAttr}); 147 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 148 } 149 150 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 151 MLIRContext *ctx = result.getContext(); 152 153 // Sizes of parsed variadic operands, will be updated below after parsing. 154 int32_t numDependencies = 0; 155 156 auto tokenTy = TokenType::get(ctx); 157 158 // Parse dependency tokens. 159 if (succeeded(parser.parseOptionalLSquare())) { 160 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 161 if (parser.parseOperandList(tokenArgs) || 162 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 163 parser.parseRSquare()) 164 return failure(); 165 166 numDependencies = tokenArgs.size(); 167 } 168 169 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 170 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 171 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 172 SmallVector<Type, 4> valueTypes; 173 SmallVector<Type, 4> unwrappedTypes; 174 175 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 176 auto parseAsyncValueArg = [&]() -> ParseResult { 177 if (parser.parseOperand(valueArgs.emplace_back()) || 178 parser.parseKeyword("as") || 179 parser.parseOperand(unwrappedArgs.emplace_back()) || 180 parser.parseColonType(valueTypes.emplace_back())) 181 return failure(); 182 183 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 184 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 185 186 return success(); 187 }; 188 189 auto argsLoc = parser.getCurrentLocation(); 190 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen, 191 parseAsyncValueArg) || 192 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) 193 return failure(); 194 195 int32_t numOperands = valueArgs.size(); 196 197 // Add derived `operand_segment_sizes` attribute based on parsed operands. 198 auto operandSegmentSizes = DenseIntElementsAttr::get( 199 VectorType::get({2}, parser.getBuilder().getI32Type()), 200 {numDependencies, numOperands}); 201 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 202 203 // Parse the types of results returned from the async execute op. 204 SmallVector<Type, 4> resultTypes; 205 if (parser.parseOptionalArrowTypeList(resultTypes)) 206 return failure(); 207 208 // Async execute first result is always a completion token. 209 parser.addTypeToList(tokenTy, result.types); 210 parser.addTypesToList(resultTypes, result.types); 211 212 // Parse operation attributes. 213 NamedAttrList attrs; 214 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 215 return failure(); 216 result.addAttributes(attrs); 217 218 // Parse asynchronous region. 219 Region *body = result.addRegion(); 220 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 221 /*argTypes=*/{unwrappedTypes}, 222 /*enableNameShadowing=*/false)) 223 return failure(); 224 225 return success(); 226 } 227 228 static LogicalResult verify(ExecuteOp op) { 229 // Unwrap async.execute value operands types. 230 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 231 return operand.getType().cast<ValueType>().getValueType(); 232 }); 233 234 // Verify that unwrapped argument types matches the body region arguments. 235 if (op.body().getArgumentTypes() != unwrappedTypes) 236 return op.emitOpError("async body region argument types do not match the " 237 "execute operation arguments types"); 238 239 return success(); 240 } 241 242 //===----------------------------------------------------------------------===// 243 /// CreateGroupOp 244 //===----------------------------------------------------------------------===// 245 246 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, 247 PatternRewriter &rewriter) { 248 // Find all `await_all` users of the group. 249 llvm::SmallVector<AwaitAllOp> awaitAllUsers; 250 251 auto isAwaitAll = [&](Operation *op) -> bool { 252 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { 253 awaitAllUsers.push_back(awaitAll); 254 return true; 255 } 256 return false; 257 }; 258 259 // Check if all users of the group are `await_all` operations. 260 if (!llvm::all_of(op->getUsers(), isAwaitAll)) 261 return failure(); 262 263 // If group is only awaited without adding anything to it, we can safely erase 264 // the create operation and all users. 265 for (AwaitAllOp awaitAll : awaitAllUsers) 266 rewriter.eraseOp(awaitAll); 267 rewriter.eraseOp(op); 268 269 return success(); 270 } 271 272 //===----------------------------------------------------------------------===// 273 /// AwaitOp 274 //===----------------------------------------------------------------------===// 275 276 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 277 ArrayRef<NamedAttribute> attrs) { 278 result.addOperands({operand}); 279 result.attributes.append(attrs.begin(), attrs.end()); 280 281 // Add unwrapped async.value type to the returned values types. 282 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 283 result.addTypes(valueType.getValueType()); 284 } 285 286 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 287 Type &resultType) { 288 if (parser.parseType(operandType)) 289 return failure(); 290 291 // Add unwrapped async.value type to the returned values types. 292 if (auto valueType = operandType.dyn_cast<ValueType>()) 293 resultType = valueType.getValueType(); 294 295 return success(); 296 } 297 298 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 299 Type operandType, Type resultType) { 300 p << operandType; 301 } 302 303 static LogicalResult verify(AwaitOp op) { 304 Type argType = op.operand().getType(); 305 306 // Awaiting on a token does not have any results. 307 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 308 return op.emitOpError("awaiting on a token must have empty result"); 309 310 // Awaiting on a value unwraps the async value type. 311 if (auto value = argType.dyn_cast<ValueType>()) { 312 if (*op.getResultType() != value.getValueType()) 313 return op.emitOpError() 314 << "result type " << *op.getResultType() 315 << " does not match async value type " << value.getValueType(); 316 } 317 318 return success(); 319 } 320 321 //===----------------------------------------------------------------------===// 322 // TableGen'd op method definitions 323 //===----------------------------------------------------------------------===// 324 325 #define GET_OP_CLASSES 326 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 327 328 //===----------------------------------------------------------------------===// 329 // TableGen'd type method definitions 330 //===----------------------------------------------------------------------===// 331 332 #define GET_TYPEDEF_CLASSES 333 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 334 335 void ValueType::print(AsmPrinter &printer) const { 336 printer << "<"; 337 printer.printType(getValueType()); 338 printer << '>'; 339 } 340 341 Type ValueType::parse(mlir::AsmParser &parser) { 342 Type ty; 343 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 344 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 345 return Type(); 346 } 347 return ValueType::get(ty); 348 } 349 350 /// Print a type registered to this dialect. 351 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 352 if (failed(generatedTypePrinter(type, os))) 353 llvm_unreachable("unexpected 'async' type kind"); 354 } 355 356 /// Parse a type registered to this dialect. 357 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 358 StringRef typeTag; 359 if (parser.parseKeyword(&typeTag)) 360 return Type(); 361 Type genType; 362 auto parseResult = generatedTypeParser(parser, typeTag, genType); 363 if (parseResult.hasValue()) 364 return genType; 365 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; 366 return {}; 367 } 368