1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 void AsyncDialect::initialize() { 18 addOperations< 19 #define GET_OP_LIST 20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 21 >(); 22 addTypes< 23 #define GET_TYPEDEF_LIST 24 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 25 >(); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // YieldOp 30 //===----------------------------------------------------------------------===// 31 32 static LogicalResult verify(YieldOp op) { 33 // Get the underlying value types from async values returned from the 34 // parent `async.execute` operation. 35 auto executeOp = op->getParentOfType<ExecuteOp>(); 36 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 37 return result.getType().cast<ValueType>().getValueType(); 38 }); 39 40 if (op.getOperandTypes() != types) 41 return op.emitOpError("operand types do not match the types returned from " 42 "the parent ExecuteOp"); 43 44 return success(); 45 } 46 47 //===----------------------------------------------------------------------===// 48 /// ExecuteOp 49 //===----------------------------------------------------------------------===// 50 51 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 52 53 void ExecuteOp::getNumRegionInvocations( 54 ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { 55 (void)operands; 56 assert(countPerRegion.empty()); 57 countPerRegion.push_back(1); 58 } 59 60 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, 61 ArrayRef<Attribute> operands, 62 SmallVectorImpl<RegionSuccessor> ®ions) { 63 // The `body` region branch back to the parent operation. 64 if (index.hasValue()) { 65 assert(*index == 0); 66 regions.push_back(RegionSuccessor(getResults())); 67 return; 68 } 69 70 // Otherwise the successor is the body region. 71 regions.push_back(RegionSuccessor(&body())); 72 } 73 74 void ExecuteOp::build(OpBuilder &builder, OperationState &result, 75 TypeRange resultTypes, ValueRange dependencies, 76 ValueRange operands, BodyBuilderFn bodyBuilder) { 77 78 result.addOperands(dependencies); 79 result.addOperands(operands); 80 81 // Add derived `operand_segment_sizes` attribute based on parsed operands. 82 int32_t numDependencies = dependencies.size(); 83 int32_t numOperands = operands.size(); 84 auto operandSegmentSizes = DenseIntElementsAttr::get( 85 VectorType::get({2}, builder.getIntegerType(32)), 86 {numDependencies, numOperands}); 87 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 88 89 // First result is always a token, and then `resultTypes` wrapped into 90 // `async.value`. 91 result.addTypes({TokenType::get(result.getContext())}); 92 for (Type type : resultTypes) 93 result.addTypes(ValueType::get(type)); 94 95 // Add a body region with block arguments as unwrapped async value operands. 96 Region *bodyRegion = result.addRegion(); 97 bodyRegion->push_back(new Block); 98 Block &bodyBlock = bodyRegion->front(); 99 for (Value operand : operands) { 100 auto valueType = operand.getType().dyn_cast<ValueType>(); 101 bodyBlock.addArgument(valueType ? valueType.getValueType() 102 : operand.getType()); 103 } 104 105 // Create the default terminator if the builder is not provided and if the 106 // expected result is empty. Otherwise, leave this to the caller 107 // because we don't know which values to return from the execute op. 108 if (resultTypes.empty() && !bodyBuilder) { 109 OpBuilder::InsertionGuard guard(builder); 110 builder.setInsertionPointToStart(&bodyBlock); 111 builder.create<async::YieldOp>(result.location, ValueRange()); 112 } else if (bodyBuilder) { 113 OpBuilder::InsertionGuard guard(builder); 114 builder.setInsertionPointToStart(&bodyBlock); 115 bodyBuilder(builder, result.location, bodyBlock.getArguments()); 116 } 117 } 118 119 static void print(OpAsmPrinter &p, ExecuteOp op) { 120 p << op.getOperationName(); 121 122 // [%tokens,...] 123 if (!op.dependencies().empty()) 124 p << " [" << op.dependencies() << "]"; 125 126 // (%value as %unwrapped: !async.value<!arg.type>, ...) 127 if (!op.operands().empty()) { 128 p << " ("; 129 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 130 p << operand << " as " << op.body().front().getArgument(n++) << ": " 131 << operand.getType(); 132 }); 133 p << ")"; 134 } 135 136 // -> (!async.value<!return.type>, ...) 137 p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1)); 138 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr}); 139 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 140 } 141 142 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 143 MLIRContext *ctx = result.getContext(); 144 145 // Sizes of parsed variadic operands, will be updated below after parsing. 146 int32_t numDependencies = 0; 147 int32_t numOperands = 0; 148 149 auto tokenTy = TokenType::get(ctx); 150 151 // Parse dependency tokens. 152 if (succeeded(parser.parseOptionalLSquare())) { 153 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 154 if (parser.parseOperandList(tokenArgs) || 155 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 156 parser.parseRSquare()) 157 return failure(); 158 159 numDependencies = tokenArgs.size(); 160 } 161 162 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 163 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 164 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 165 SmallVector<Type, 4> valueTypes; 166 SmallVector<Type, 4> unwrappedTypes; 167 168 if (succeeded(parser.parseOptionalLParen())) { 169 auto argsLoc = parser.getCurrentLocation(); 170 171 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 172 auto parseAsyncValueArg = [&]() -> ParseResult { 173 if (parser.parseOperand(valueArgs.emplace_back()) || 174 parser.parseKeyword("as") || 175 parser.parseOperand(unwrappedArgs.emplace_back()) || 176 parser.parseColonType(valueTypes.emplace_back())) 177 return failure(); 178 179 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 180 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 181 182 return success(); 183 }; 184 185 // If the next token is `)` skip async value arguments parsing. 186 if (failed(parser.parseOptionalRParen())) { 187 do { 188 if (parseAsyncValueArg()) 189 return failure(); 190 } while (succeeded(parser.parseOptionalComma())); 191 192 if (parser.parseRParen() || 193 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 194 result.operands)) 195 return failure(); 196 } 197 198 numOperands = valueArgs.size(); 199 } 200 201 // Add derived `operand_segment_sizes` attribute based on parsed operands. 202 auto operandSegmentSizes = DenseIntElementsAttr::get( 203 VectorType::get({2}, parser.getBuilder().getI32Type()), 204 {numDependencies, numOperands}); 205 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 206 207 // Parse the types of results returned from the async execute op. 208 SmallVector<Type, 4> resultTypes; 209 if (parser.parseOptionalArrowTypeList(resultTypes)) 210 return failure(); 211 212 // Async execute first result is always a completion token. 213 parser.addTypeToList(tokenTy, result.types); 214 parser.addTypesToList(resultTypes, result.types); 215 216 // Parse operation attributes. 217 NamedAttrList attrs; 218 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 219 return failure(); 220 result.addAttributes(attrs); 221 222 // Parse asynchronous region. 223 Region *body = result.addRegion(); 224 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 225 /*argTypes=*/{unwrappedTypes}, 226 /*enableNameShadowing=*/false)) 227 return failure(); 228 229 return success(); 230 } 231 232 static LogicalResult verify(ExecuteOp op) { 233 // Unwrap async.execute value operands types. 234 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 235 return operand.getType().cast<ValueType>().getValueType(); 236 }); 237 238 // Verify that unwrapped argument types matches the body region arguments. 239 if (op.body().getArgumentTypes() != unwrappedTypes) 240 return op.emitOpError("async body region argument types do not match the " 241 "execute operation arguments types"); 242 243 return success(); 244 } 245 246 //===----------------------------------------------------------------------===// 247 /// AwaitOp 248 //===----------------------------------------------------------------------===// 249 250 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 251 ArrayRef<NamedAttribute> attrs) { 252 result.addOperands({operand}); 253 result.attributes.append(attrs.begin(), attrs.end()); 254 255 // Add unwrapped async.value type to the returned values types. 256 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 257 result.addTypes(valueType.getValueType()); 258 } 259 260 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 261 Type &resultType) { 262 if (parser.parseType(operandType)) 263 return failure(); 264 265 // Add unwrapped async.value type to the returned values types. 266 if (auto valueType = operandType.dyn_cast<ValueType>()) 267 resultType = valueType.getValueType(); 268 269 return success(); 270 } 271 272 static void printAwaitResultType(OpAsmPrinter &p, Operation *op, 273 Type operandType, Type resultType) { 274 p << operandType; 275 } 276 277 static LogicalResult verify(AwaitOp op) { 278 Type argType = op.operand().getType(); 279 280 // Awaiting on a token does not have any results. 281 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 282 return op.emitOpError("awaiting on a token must have empty result"); 283 284 // Awaiting on a value unwraps the async value type. 285 if (auto value = argType.dyn_cast<ValueType>()) { 286 if (*op.getResultType() != value.getValueType()) 287 return op.emitOpError() 288 << "result type " << *op.getResultType() 289 << " does not match async value type " << value.getValueType(); 290 } 291 292 return success(); 293 } 294 295 //===----------------------------------------------------------------------===// 296 // TableGen'd op method definitions 297 //===----------------------------------------------------------------------===// 298 299 #define GET_OP_CLASSES 300 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 301 302 //===----------------------------------------------------------------------===// 303 // TableGen'd type method definitions 304 //===----------------------------------------------------------------------===// 305 306 #define GET_TYPEDEF_CLASSES 307 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" 308 309 void ValueType::print(DialectAsmPrinter &printer) const { 310 printer << getMnemonic(); 311 printer << "<"; 312 printer.printType(getValueType()); 313 printer << '>'; 314 } 315 316 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { 317 Type ty; 318 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 319 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 320 return Type(); 321 } 322 return ValueType::get(ty); 323 } 324 325 /// Print a type registered to this dialect. 326 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 327 if (failed(generatedTypePrinter(type, os))) 328 llvm_unreachable("unexpected 'async' type kind"); 329 } 330 331 /// Parse a type registered to this dialect. 332 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 333 StringRef mnemonic; 334 if (parser.parseKeyword(&mnemonic)) 335 return Type(); 336 337 return generatedTypeParser(getContext(), parser, mnemonic); 338 } 339