1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 void AsyncDialect::initialize() { 18 addOperations< 19 #define GET_OP_LIST 20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 21 >(); 22 addTypes< 23 #define GET_TYPEDEF_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 25 >(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // YieldOp 30 //===----------------------------------------------------------------------===// 31 32 static LogicalResult verify(YieldOp op) { 33 // Get the underlying value types from async values returned from the 34 // parent `async.execute` operation. 35 auto executeOp = op->getParentOfType<ExecuteOp>(); 36 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 37 return result.getType().cast<ValueType>().getValueType(); 38 }); 39 40 if (op.getOperandTypes() != types) 41 return op.emitOpError("operand types do not match the types returned from " 42 "the parent ExecuteOp"); 43 44 return success(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 /// ExecuteOp 49 //===----------------------------------------------------------------------===// 50 51 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 52 53 void ExecuteOp::getNumRegionInvocations( 54 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { 55 (void)operands; 56 assert(countPerRegion.empty()); 57 countPerRegion.push_back(1); 58 } 59 60 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 61 ArrayRef<Attribute> operands, 62 SmallVectorImpl<RegionSuccessor> ®ions) { 63 // The `body` region branch back to the parent operation. 64 if (index.hasValue()) { 65 assert(*index == 0); 66 regions.push_back(RegionSuccessor(getResults())); 67 return; 68 } 69 70 // Otherwise the successor is the body region. 71 regions.push_back(RegionSuccessor(&body())); 72 } 73 74 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 75 TypeRange resultTypes, ValueRange dependencies, 76 ValueRange operands, BodyBuilderFn bodyBuilder) { 77 78 result.addOperands(dependencies); 79 result.addOperands(operands); 80 81 // Add derived `operand_segment_sizes` attribute based on parsed operands. 82 int32_t numDependencies = dependencies.size(); 83 int32_t numOperands = operands.size(); 84 auto operandSegmentSizes = DenseIntElementsAttr::get( 85 VectorType::get({2}, builder.getIntegerType(32)), 86 {numDependencies, numOperands}); 87 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 88 89 // First result is always a token, and then `resultTypes` wrapped into 90 // `async.value`. 91 result.addTypes({TokenType::get(result.getContext())}); 92 for (Type type : resultTypes) 93 result.addTypes(ValueType::get(type)); 94 95 // Add a body region with block arguments as unwrapped async value operands. 96 Region *bodyRegion = result.addRegion(); 97 bodyRegion->push_back(new Block); 98 Block &bodyBlock = bodyRegion->front(); 99 for (Value operand : operands) { 100 auto valueType = operand.getType().dyn_cast<ValueType>(); 101 bodyBlock.addArgument(valueType ? valueType.getValueType() 102 : operand.getType()); 103 } 104 105 // Create the default terminator if the builder is not provided and if the 106 // expected result is empty. Otherwise, leave this to the caller 107 // because we don't know which values to return from the execute op. 108 if (resultTypes.empty() && !bodyBuilder) { 109 OpBuilder::InsertionGuard guard(builder); 110 builder.setInsertionPointToStart(&bodyBlock); 111 builder.create<async::YieldOp>(result.location, ValueRange()); 112 } else if (bodyBuilder) { 113 OpBuilder::InsertionGuard guard(builder); 114 builder.setInsertionPointToStart(&bodyBlock); 115 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 116 } 117 } 118 119 static void print(OpAsmPrinter &p, ExecuteOp op) { 120 p << op.getOperationName(); 121 122 // [%tokens,...] 123 if (!op.dependencies().empty()) 124 p << " [" << op.dependencies() << "]"; 125 126 // (%value as %unwrapped: !async.value<!arg.type>, ...) 127 if (!op.operands().empty()) { 128 p << " ("; 129 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 130 p << operand << " as " << op.body().front().getArgument(n++) << ": " 131 << operand.getType(); 132 }); 133 p << ")"; 134 } 135 136 // -> (!async.value<!return.type>, ...) 137 p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes())); 138 p.printOptionalAttrDictWithKeyword(op->getAttrs(), 139 {kOperandSegmentSizesAttr}); 140 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 141 } 142 143 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 144 MLIRContext *ctx = result.getContext(); 145 146 // Sizes of parsed variadic operands, will be updated below after parsing. 147 int32_t numDependencies = 0; 148 int32_t numOperands = 0; 149 150 auto tokenTy = TokenType::get(ctx); 151 152 // Parse dependency tokens. 153 if (succeeded(parser.parseOptionalLSquare())) { 154 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 155 if (parser.parseOperandList(tokenArgs) || 156 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 157 parser.parseRSquare()) 158 return failure(); 159 160 numDependencies = tokenArgs.size(); 161 } 162 163 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 164 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 165 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 166 SmallVector<Type, 4> valueTypes; 167 SmallVector<Type, 4> unwrappedTypes; 168 169 if (succeeded(parser.parseOptionalLParen())) { 170 auto argsLoc = parser.getCurrentLocation(); 171 172 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 173 auto parseAsyncValueArg = [&]() -> ParseResult { 174 if (parser.parseOperand(valueArgs.emplace_back()) || 175 parser.parseKeyword("as") || 176 parser.parseOperand(unwrappedArgs.emplace_back()) || 177 parser.parseColonType(valueTypes.emplace_back())) 178 return failure(); 179 180 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 181 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 182 183 return success(); 184 }; 185 186 // If the next token is `)` skip async value arguments parsing. 187 if (failed(parser.parseOptionalRParen())) { 188 do { 189 if (parseAsyncValueArg()) 190 return failure(); 191 } while (succeeded(parser.parseOptionalComma())); 192 193 if (parser.parseRParen() || 194 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 195 result.operands)) 196 return failure(); 197 } 198 199 numOperands = valueArgs.size(); 200 } 201 202 // Add derived `operand_segment_sizes` attribute based on parsed operands. 203 auto operandSegmentSizes = DenseIntElementsAttr::get( 204 VectorType::get({2}, parser.getBuilder().getI32Type()), 205 {numDependencies, numOperands}); 206 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 207 208 // Parse the types of results returned from the async execute op. 209 SmallVector<Type, 4> resultTypes; 210 if (parser.parseOptionalArrowTypeList(resultTypes)) 211 return failure(); 212 213 // Async execute first result is always a completion token. 214 parser.addTypeToList(tokenTy, result.types); 215 parser.addTypesToList(resultTypes, result.types); 216 217 // Parse operation attributes. 218 NamedAttrList attrs; 219 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 220 return failure(); 221 result.addAttributes(attrs); 222 223 // Parse asynchronous region. 224 Region *body = result.addRegion(); 225 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 226 /*argTypes=*/{unwrappedTypes}, 227 /*enableNameShadowing=*/false)) 228 return failure(); 229 230 return success(); 231 } 232 233 static LogicalResult verify(ExecuteOp op) { 234 // Unwrap async.execute value operands types. 235 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 236 return operand.getType().cast<ValueType>().getValueType(); 237 }); 238 239 // Verify that unwrapped argument types matches the body region arguments. 240 if (op.body().getArgumentTypes() != unwrappedTypes) 241 return op.emitOpError("async body region argument types do not match the " 242 "execute operation arguments types"); 243 244 return success(); 245 } 246 247 //===----------------------------------------------------------------------===// 248 /// AwaitOp 249 //===----------------------------------------------------------------------===// 250 251 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 252 ArrayRef<NamedAttribute> attrs) { 253 result.addOperands({operand}); 254 result.attributes.append(attrs.begin(), attrs.end()); 255 256 // Add unwrapped async.value type to the returned values types. 257 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 258 result.addTypes(valueType.getValueType()); 259 } 260 261 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 262 Type &resultType) { 263 if (parser.parseType(operandType)) 264 return failure(); 265 266 // Add unwrapped async.value type to the returned values types. 267 if (auto valueType = operandType.dyn_cast<ValueType>()) 268 resultType = valueType.getValueType(); 269 270 return success(); 271 } 272 273 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 274 Type operandType, Type resultType) { 275 p << operandType; 276 } 277 278 static LogicalResult verify(AwaitOp op) { 279 Type argType = op.operand().getType(); 280 281 // Awaiting on a token does not have any results. 282 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 283 return op.emitOpError("awaiting on a token must have empty result"); 284 285 // Awaiting on a value unwraps the async value type. 286 if (auto value = argType.dyn_cast<ValueType>()) { 287 if (*op.getResultType() != value.getValueType()) 288 return op.emitOpError() 289 << "result type " << *op.getResultType() 290 << " does not match async value type " << value.getValueType(); 291 } 292 293 return success(); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // TableGen'd op method definitions 298 //===----------------------------------------------------------------------===// 299 300 #define GET_OP_CLASSES 301 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 302 303 //===----------------------------------------------------------------------===// 304 // TableGen'd type method definitions 305 //===----------------------------------------------------------------------===// 306 307 #define GET_TYPEDEF_CLASSES 308 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 309 310 void ValueType::print(DialectAsmPrinter &printer) const { 311 printer << getMnemonic(); 312 printer << "<"; 313 printer.printType(getValueType()); 314 printer << '>'; 315 } 316 317 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { 318 Type ty; 319 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 320 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 321 return Type(); 322 } 323 return ValueType::get(ty); 324 } 325 326 /// Print a type registered to this dialect. 327 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 328 if (failed(generatedTypePrinter(type, os))) 329 llvm_unreachable("unexpected 'async' type kind"); 330 } 331 332 /// Parse a type registered to this dialect. 333 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 334 StringRef typeTag; 335 if (parser.parseKeyword(&typeTag)) 336 return Type(); 337 Type genType; 338 auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), 339 parser, typeTag, genType); 340 if (parseResult.hasValue()) 341 return genType; 342 parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag; 343 return {}; 344 } 345