1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc" 18 19 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName; 20 21 void AsyncDialect::initialize() { 22 addOperations< 23 #define GET_OP_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 25 >(); 26 addTypes< 27 #define GET_TYPEDEF_LIST 28 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 29 >(); 30 } 31 32 //===----------------------------------------------------------------------===// 33 // YieldOp 34 //===----------------------------------------------------------------------===// 35 36 LogicalResult YieldOp::verify() { 37 // Get the underlying value types from async values returned from the 38 // parent `async.execute` operation. 39 auto executeOp = (*this)->getParentOfType<ExecuteOp>(); 40 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 41 return result.getType().cast<ValueType>().getValueType(); 42 }); 43 44 if (getOperandTypes() != types) 45 return emitOpError("operand types do not match the types returned from " 46 "the parent ExecuteOp"); 47 48 return success(); 49 } 50 51 MutableOperandRange 52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) { 53 return operandsMutable(); 54 } 55 56 //===----------------------------------------------------------------------===// 57 /// ExecuteOp 58 //===----------------------------------------------------------------------===// 59 60 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 61 62 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { 63 assert(index == 0 && "invalid region index"); 64 return operands(); 65 } 66 67 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { 68 const auto getValueOrTokenType = [](Type type) { 69 if (auto value = type.dyn_cast<ValueType>()) 70 return value.getValueType(); 71 return type; 72 }; 73 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs); 74 } 75 76 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 77 ArrayRef<Attribute>, 78 SmallVectorImpl<RegionSuccessor> ®ions) { 79 // The `body` region branch back to the parent operation. 80 if (index.hasValue()) { 81 assert(*index == 0 && "invalid region index"); 82 regions.push_back(RegionSuccessor(results())); 83 return; 84 } 85 86 // Otherwise the successor is the body region. 87 regions.push_back(RegionSuccessor(&body(), body().getArguments())); 88 } 89 90 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 91 TypeRange resultTypes, ValueRange dependencies, 92 ValueRange operands, BodyBuilderFn bodyBuilder) { 93 94 result.addOperands(dependencies); 95 result.addOperands(operands); 96 97 // Add derived `operand_segment_sizes` attribute based on parsed operands. 98 int32_t numDependencies = dependencies.size(); 99 int32_t numOperands = operands.size(); 100 auto operandSegmentSizes = DenseIntElementsAttr::get( 101 VectorType::get({2}, builder.getIntegerType(32)), 102 {numDependencies, numOperands}); 103 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 104 105 // First result is always a token, and then `resultTypes` wrapped into 106 // `async.value`. 107 result.addTypes({TokenType::get(result.getContext())}); 108 for (Type type : resultTypes) 109 result.addTypes(ValueType::get(type)); 110 111 // Add a body region with block arguments as unwrapped async value operands. 112 Region *bodyRegion = result.addRegion(); 113 bodyRegion->push_back(new Block); 114 Block &bodyBlock = bodyRegion->front(); 115 for (Value operand : operands) { 116 auto valueType = operand.getType().dyn_cast<ValueType>(); 117 bodyBlock.addArgument(valueType ? valueType.getValueType() 118 : operand.getType(), 119 operand.getLoc()); 120 } 121 122 // Create the default terminator if the builder is not provided and if the 123 // expected result is empty. Otherwise, leave this to the caller 124 // because we don't know which values to return from the execute op. 125 if (resultTypes.empty() && !bodyBuilder) { 126 OpBuilder::InsertionGuard guard(builder); 127 builder.setInsertionPointToStart(&bodyBlock); 128 builder.create<async::YieldOp>(result.location, ValueRange()); 129 } else if (bodyBuilder) { 130 OpBuilder::InsertionGuard guard(builder); 131 builder.setInsertionPointToStart(&bodyBlock); 132 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 133 } 134 } 135 136 void ExecuteOp::print(OpAsmPrinter &p) { 137 // [%tokens,...] 138 if (!dependencies().empty()) 139 p << " [" << dependencies() << "]"; 140 141 // (%value as %unwrapped: !async.value<!arg.type>, ...) 142 if (!operands().empty()) { 143 p << " ("; 144 Block *entry = body().empty() ? nullptr : &body().front(); 145 llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable { 146 Value argument = entry ? entry->getArgument(n++) : Value(); 147 p << operand << " as " << argument << ": " << operand.getType(); 148 }); 149 p << ")"; 150 } 151 152 // -> (!async.value<!return.type>, ...) 153 p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes())); 154 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), 155 {kOperandSegmentSizesAttr}); 156 p << ' '; 157 p.printRegion(body(), /*printEntryBlockArgs=*/false); 158 } 159 160 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) { 161 MLIRContext *ctx = result.getContext(); 162 163 // Sizes of parsed variadic operands, will be updated below after parsing. 164 int32_t numDependencies = 0; 165 166 auto tokenTy = TokenType::get(ctx); 167 168 // Parse dependency tokens. 169 if (succeeded(parser.parseOptionalLSquare())) { 170 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 171 if (parser.parseOperandList(tokenArgs) || 172 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 173 parser.parseRSquare()) 174 return failure(); 175 176 numDependencies = tokenArgs.size(); 177 } 178 179 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 180 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 181 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 182 SmallVector<Type, 4> valueTypes; 183 SmallVector<Type, 4> unwrappedTypes; 184 185 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 186 auto parseAsyncValueArg = [&]() -> ParseResult { 187 if (parser.parseOperand(valueArgs.emplace_back()) || 188 parser.parseKeyword("as") || 189 parser.parseOperand(unwrappedArgs.emplace_back()) || 190 parser.parseColonType(valueTypes.emplace_back())) 191 return failure(); 192 193 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 194 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 195 196 return success(); 197 }; 198 199 auto argsLoc = parser.getCurrentLocation(); 200 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen, 201 parseAsyncValueArg) || 202 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands)) 203 return failure(); 204 205 int32_t numOperands = valueArgs.size(); 206 207 // Add derived `operand_segment_sizes` attribute based on parsed operands. 208 auto operandSegmentSizes = DenseIntElementsAttr::get( 209 VectorType::get({2}, parser.getBuilder().getI32Type()), 210 {numDependencies, numOperands}); 211 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 212 213 // Parse the types of results returned from the async execute op. 214 SmallVector<Type, 4> resultTypes; 215 if (parser.parseOptionalArrowTypeList(resultTypes)) 216 return failure(); 217 218 // Async execute first result is always a completion token. 219 parser.addTypeToList(tokenTy, result.types); 220 parser.addTypesToList(resultTypes, result.types); 221 222 // Parse operation attributes. 223 NamedAttrList attrs; 224 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 225 return failure(); 226 result.addAttributes(attrs); 227 228 // Parse asynchronous region. 229 Region *body = result.addRegion(); 230 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 231 /*argTypes=*/{unwrappedTypes}, 232 /*argLocations=*/{}, 233 /*enableNameShadowing=*/false)) 234 return failure(); 235 236 return success(); 237 } 238 239 LogicalResult ExecuteOp::verifyRegions() { 240 // Unwrap async.execute value operands types. 241 auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) { 242 return operand.getType().cast<ValueType>().getValueType(); 243 }); 244 245 // Verify that unwrapped argument types matches the body region arguments. 246 if (body().getArgumentTypes() != unwrappedTypes) 247 return emitOpError("async body region argument types do not match the " 248 "execute operation arguments types"); 249 250 return success(); 251 } 252 253 //===----------------------------------------------------------------------===// 254 /// CreateGroupOp 255 //===----------------------------------------------------------------------===// 256 257 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op, 258 PatternRewriter &rewriter) { 259 // Find all `await_all` users of the group. 260 llvm::SmallVector<AwaitAllOp> awaitAllUsers; 261 262 auto isAwaitAll = [&](Operation *op) -> bool { 263 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) { 264 awaitAllUsers.push_back(awaitAll); 265 return true; 266 } 267 return false; 268 }; 269 270 // Check if all users of the group are `await_all` operations. 271 if (!llvm::all_of(op->getUsers(), isAwaitAll)) 272 return failure(); 273 274 // If group is only awaited without adding anything to it, we can safely erase 275 // the create operation and all users. 276 for (AwaitAllOp awaitAll : awaitAllUsers) 277 rewriter.eraseOp(awaitAll); 278 rewriter.eraseOp(op); 279 280 return success(); 281 } 282 283 //===----------------------------------------------------------------------===// 284 /// AwaitOp 285 //===----------------------------------------------------------------------===// 286 287 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 288 ArrayRef<NamedAttribute> attrs) { 289 result.addOperands({operand}); 290 result.attributes.append(attrs.begin(), attrs.end()); 291 292 // Add unwrapped async.value type to the returned values types. 293 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 294 result.addTypes(valueType.getValueType()); 295 } 296 297 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 298 Type &resultType) { 299 if (parser.parseType(operandType)) 300 return failure(); 301 302 // Add unwrapped async.value type to the returned values types. 303 if (auto valueType = operandType.dyn_cast<ValueType>()) 304 resultType = valueType.getValueType(); 305 306 return success(); 307 } 308 309 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 310 Type operandType, Type resultType) { 311 p << operandType; 312 } 313 314 LogicalResult AwaitOp::verify() { 315 Type argType = operand().getType(); 316 317 // Awaiting on a token does not have any results. 318 if (argType.isa<TokenType>() && !getResultTypes().empty()) 319 return emitOpError("awaiting on a token must have empty result"); 320 321 // Awaiting on a value unwraps the async value type. 322 if (auto value = argType.dyn_cast<ValueType>()) { 323 if (*getResultType() != value.getValueType()) 324 return emitOpError() << "result type " << *getResultType() 325 << " does not match async value type " 326 << value.getValueType(); 327 } 328 329 return success(); 330 } 331 332 //===----------------------------------------------------------------------===// 333 // TableGen'd op method definitions 334 //===----------------------------------------------------------------------===// 335 336 #define GET_OP_CLASSES 337 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 338 339 //===----------------------------------------------------------------------===// 340 // TableGen'd type method definitions 341 //===----------------------------------------------------------------------===// 342 343 #define GET_TYPEDEF_CLASSES 344 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 345 346 void ValueType::print(AsmPrinter &printer) const { 347 printer << "<"; 348 printer.printType(getValueType()); 349 printer << '>'; 350 } 351 352 Type ValueType::parse(mlir::AsmParser &parser) { 353 Type ty; 354 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 355 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 356 return Type(); 357 } 358 return ValueType::get(ty); 359 } 360