1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 void AsyncDialect::initialize() { 18 addOperations< 19 #define GET_OP_LIST 20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 21 >(); 22 addTypes<TokenType>(); 23 addTypes<ValueType>(); 24 addTypes<GroupType>(); 25 } 26 27 /// Parse a type registered to this dialect. 28 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 29 StringRef keyword; 30 if (parser.parseKeyword(&keyword)) 31 return Type(); 32 33 if (keyword == "token") 34 return TokenType::get(getContext()); 35 36 if (keyword == "value") { 37 Type ty; 38 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 39 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 40 return Type(); 41 } 42 return ValueType::get(ty); 43 } 44 45 parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword; 46 return Type(); 47 } 48 49 /// Print a type registered to this dialect. 50 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 51 TypeSwitch<Type>(type) 52 .Case<TokenType>([&](TokenType) { os << "token"; }) 53 .Case<ValueType>([&](ValueType valueTy) { 54 os << "value<"; 55 os.printType(valueTy.getValueType()); 56 os << '>'; 57 }) 58 .Case<GroupType>([&](GroupType) { os << "group"; }) 59 .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); 60 } 61 62 //===----------------------------------------------------------------------===// 63 /// ValueType 64 //===----------------------------------------------------------------------===// 65 66 namespace mlir { 67 namespace async { 68 namespace detail { 69 70 // Storage for `async.value<T>` type, the only member is the wrapped type. 71 struct ValueTypeStorage : public TypeStorage { 72 ValueTypeStorage(Type valueType) : valueType(valueType) {} 73 74 /// The hash key used for uniquing. 75 using KeyTy = Type; 76 bool operator==(const KeyTy &key) const { return key == valueType; } 77 78 /// Construction. 79 static ValueTypeStorage *construct(TypeStorageAllocator &allocator, 80 Type valueType) { 81 return new (allocator.allocate<ValueTypeStorage>()) 82 ValueTypeStorage(valueType); 83 } 84 85 Type valueType; 86 }; 87 88 } // namespace detail 89 } // namespace async 90 } // namespace mlir 91 92 ValueType ValueType::get(Type valueType) { 93 return Base::get(valueType.getContext(), valueType); 94 } 95 96 Type ValueType::getValueType() { return getImpl()->valueType; } 97 98 //===----------------------------------------------------------------------===// 99 // YieldOp 100 //===----------------------------------------------------------------------===// 101 102 static LogicalResult verify(YieldOp op) { 103 // Get the underlying value types from async values returned from the 104 // parent `async.execute` operation. 105 auto executeOp = op->getParentOfType<ExecuteOp>(); 106 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 107 return result.getType().cast<ValueType>().getValueType(); 108 }); 109 110 if (op.getOperandTypes() != types) 111 return op.emitOpError("operand types do not match the types returned from " 112 "the parent ExecuteOp"); 113 114 return success(); 115 } 116 117 //===----------------------------------------------------------------------===// 118 /// ExecuteOp 119 //===----------------------------------------------------------------------===// 120 121 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 122 123 void ExecuteOp::getNumRegionInvocations( 124 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { 125 (void)operands; 126 assert(countPerRegion.empty()); 127 countPerRegion.push_back(1); 128 } 129 130 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 131 ArrayRef<Attribute> operands, 132 SmallVectorImpl<RegionSuccessor> ®ions) { 133 // The `body` region branch back to the parent operation. 134 if (index.hasValue()) { 135 assert(*index == 0); 136 regions.push_back(RegionSuccessor(getResults())); 137 return; 138 } 139 140 // Otherwise the successor is the body region. 141 regions.push_back(RegionSuccessor(&body())); 142 } 143 144 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 145 TypeRange resultTypes, ValueRange dependencies, 146 ValueRange operands, BodyBuilderFn bodyBuilder) { 147 148 result.addOperands(dependencies); 149 result.addOperands(operands); 150 151 // Add derived `operand_segment_sizes` attribute based on parsed operands. 152 int32_t numDependencies = dependencies.size(); 153 int32_t numOperands = operands.size(); 154 auto operandSegmentSizes = DenseIntElementsAttr::get( 155 VectorType::get({2}, builder.getIntegerType(32)), 156 {numDependencies, numOperands}); 157 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 158 159 // First result is always a token, and then `resultTypes` wrapped into 160 // `async.value`. 161 result.addTypes({TokenType::get(result.getContext())}); 162 for (Type type : resultTypes) 163 result.addTypes(ValueType::get(type)); 164 165 // Add a body region with block arguments as unwrapped async value operands. 166 Region *bodyRegion = result.addRegion(); 167 bodyRegion->push_back(new Block); 168 Block &bodyBlock = bodyRegion->front(); 169 for (Value operand : operands) { 170 auto valueType = operand.getType().dyn_cast<ValueType>(); 171 bodyBlock.addArgument(valueType ? valueType.getValueType() 172 : operand.getType()); 173 } 174 175 // Create the default terminator if the builder is not provided and if the 176 // expected result is empty. Otherwise, leave this to the caller 177 // because we don't know which values to return from the execute op. 178 if (resultTypes.empty() && !bodyBuilder) { 179 OpBuilder::InsertionGuard guard(builder); 180 builder.setInsertionPointToStart(&bodyBlock); 181 builder.create<async::YieldOp>(result.location, ValueRange()); 182 } else if (bodyBuilder) { 183 OpBuilder::InsertionGuard guard(builder); 184 builder.setInsertionPointToStart(&bodyBlock); 185 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 186 } 187 } 188 189 static void print(OpAsmPrinter &p, ExecuteOp op) { 190 p << op.getOperationName(); 191 192 // [%tokens,...] 193 if (!op.dependencies().empty()) 194 p << " [" << op.dependencies() << "]"; 195 196 // (%value as %unwrapped: !async.value<!arg.type>, ...) 197 if (!op.operands().empty()) { 198 p << " ("; 199 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 200 p << operand << " as " << op.body().front().getArgument(n++) << ": " 201 << operand.getType(); 202 }); 203 p << ")"; 204 } 205 206 // -> (!async.value<!return.type>, ...) 207 p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1)); 208 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr}); 209 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 210 } 211 212 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 213 MLIRContext *ctx = result.getContext(); 214 215 // Sizes of parsed variadic operands, will be updated below after parsing. 216 int32_t numDependencies = 0; 217 int32_t numOperands = 0; 218 219 auto tokenTy = TokenType::get(ctx); 220 221 // Parse dependency tokens. 222 if (succeeded(parser.parseOptionalLSquare())) { 223 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 224 if (parser.parseOperandList(tokenArgs) || 225 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 226 parser.parseRSquare()) 227 return failure(); 228 229 numDependencies = tokenArgs.size(); 230 } 231 232 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 233 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 234 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 235 SmallVector<Type, 4> valueTypes; 236 SmallVector<Type, 4> unwrappedTypes; 237 238 if (succeeded(parser.parseOptionalLParen())) { 239 auto argsLoc = parser.getCurrentLocation(); 240 241 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 242 auto parseAsyncValueArg = [&]() -> ParseResult { 243 if (parser.parseOperand(valueArgs.emplace_back()) || 244 parser.parseKeyword("as") || 245 parser.parseOperand(unwrappedArgs.emplace_back()) || 246 parser.parseColonType(valueTypes.emplace_back())) 247 return failure(); 248 249 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 250 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 251 252 return success(); 253 }; 254 255 // If the next token is `)` skip async value arguments parsing. 256 if (failed(parser.parseOptionalRParen())) { 257 do { 258 if (parseAsyncValueArg()) 259 return failure(); 260 } while (succeeded(parser.parseOptionalComma())); 261 262 if (parser.parseRParen() || 263 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 264 result.operands)) 265 return failure(); 266 } 267 268 numOperands = valueArgs.size(); 269 } 270 271 // Add derived `operand_segment_sizes` attribute based on parsed operands. 272 auto operandSegmentSizes = DenseIntElementsAttr::get( 273 VectorType::get({2}, parser.getBuilder().getI32Type()), 274 {numDependencies, numOperands}); 275 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 276 277 // Parse the types of results returned from the async execute op. 278 SmallVector<Type, 4> resultTypes; 279 if (parser.parseOptionalArrowTypeList(resultTypes)) 280 return failure(); 281 282 // Async execute first result is always a completion token. 283 parser.addTypeToList(tokenTy, result.types); 284 parser.addTypesToList(resultTypes, result.types); 285 286 // Parse operation attributes. 287 NamedAttrList attrs; 288 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 289 return failure(); 290 result.addAttributes(attrs); 291 292 // Parse asynchronous region. 293 Region *body = result.addRegion(); 294 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 295 /*argTypes=*/{unwrappedTypes}, 296 /*enableNameShadowing=*/false)) 297 return failure(); 298 299 return success(); 300 } 301 302 static LogicalResult verify(ExecuteOp op) { 303 // Unwrap async.execute value operands types. 304 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 305 return operand.getType().cast<ValueType>().getValueType(); 306 }); 307 308 // Verify that unwrapped argument types matches the body region arguments. 309 if (op.body().getArgumentTypes() != unwrappedTypes) 310 return op.emitOpError("async body region argument types do not match the " 311 "execute operation arguments types"); 312 313 return success(); 314 } 315 316 //===----------------------------------------------------------------------===// 317 /// AwaitOp 318 //===----------------------------------------------------------------------===// 319 320 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 321 ArrayRef<NamedAttribute> attrs) { 322 result.addOperands({operand}); 323 result.attributes.append(attrs.begin(), attrs.end()); 324 325 // Add unwrapped async.value type to the returned values types. 326 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 327 result.addTypes(valueType.getValueType()); 328 } 329 330 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 331 Type &resultType) { 332 if (parser.parseType(operandType)) 333 return failure(); 334 335 // Add unwrapped async.value type to the returned values types. 336 if (auto valueType = operandType.dyn_cast<ValueType>()) 337 resultType = valueType.getValueType(); 338 339 return success(); 340 } 341 342 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 343 Type operandType, Type resultType) { 344 p << operandType; 345 } 346 347 static LogicalResult verify(AwaitOp op) { 348 Type argType = op.operand().getType(); 349 350 // Awaiting on a token does not have any results. 351 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 352 return op.emitOpError("awaiting on a token must have empty result"); 353 354 // Awaiting on a value unwraps the async value type. 355 if (auto value = argType.dyn_cast<ValueType>()) { 356 if (*op.getResultType() != value.getValueType()) 357 return op.emitOpError() 358 << "result type " << *op.getResultType() 359 << " does not match async value type " << value.getValueType(); 360 } 361 362 return success(); 363 } 364 365 #define GET_OP_CLASSES 366 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 367