1 //===- Async.cpp - MLIR Async Operations ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Async/IR/Async.h" 10 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::async; 16 17 void AsyncDialect::initialize() { 18 addOperations< 19 #define GET_OP_LIST 20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 21 >(); 22 addTypes<TokenType>(); 23 addTypes<ValueType>(); 24 } 25 26 /// Parse a type registered to this dialect. 27 Type AsyncDialect::parseType(DialectAsmParser &parser) const { 28 StringRef keyword; 29 if (parser.parseKeyword(&keyword)) 30 return Type(); 31 32 if (keyword == "token") 33 return TokenType::get(getContext()); 34 35 if (keyword == "value") { 36 Type ty; 37 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { 38 parser.emitError(parser.getNameLoc(), "failed to parse async value type"); 39 return Type(); 40 } 41 return ValueType::get(ty); 42 } 43 44 parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword; 45 return Type(); 46 } 47 48 /// Print a type registered to this dialect. 49 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { 50 TypeSwitch<Type>(type) 51 .Case<TokenType>([&](TokenType) { os << "token"; }) 52 .Case<ValueType>([&](ValueType valueTy) { 53 os << "value<"; 54 os.printType(valueTy.getValueType()); 55 os << '>'; 56 }) 57 .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); 58 } 59 60 //===----------------------------------------------------------------------===// 61 /// ValueType 62 //===----------------------------------------------------------------------===// 63 64 namespace mlir { 65 namespace async { 66 namespace detail { 67 68 // Storage for `async.value<T>` type, the only member is the wrapped type. 69 struct ValueTypeStorage : public TypeStorage { 70 ValueTypeStorage(Type valueType) : valueType(valueType) {} 71 72 /// The hash key used for uniquing. 73 using KeyTy = Type; 74 bool operator==(const KeyTy &key) const { return key == valueType; } 75 76 /// Construction. 77 static ValueTypeStorage *construct(TypeStorageAllocator &allocator, 78 Type valueType) { 79 return new (allocator.allocate<ValueTypeStorage>()) 80 ValueTypeStorage(valueType); 81 } 82 83 Type valueType; 84 }; 85 86 } // namespace detail 87 } // namespace async 88 } // namespace mlir 89 90 ValueType ValueType::get(Type valueType) { 91 return Base::get(valueType.getContext(), valueType); 92 } 93 94 Type ValueType::getValueType() { return getImpl()->valueType; } 95 96 //===----------------------------------------------------------------------===// 97 // YieldOp 98 //===----------------------------------------------------------------------===// 99 100 static LogicalResult verify(YieldOp op) { 101 // Get the underlying value types from async values returned from the 102 // parent `async.execute` operation. 103 auto executeOp = op.getParentOfType<ExecuteOp>(); 104 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { 105 return result.getType().cast<ValueType>().getValueType(); 106 }); 107 108 if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin())) 109 return op.emitOpError("Operand types do not match the types returned from " 110 "the parent ExecuteOp"); 111 112 return success(); 113 } 114 115 //===----------------------------------------------------------------------===// 116 /// ExecuteOp 117 //===----------------------------------------------------------------------===// 118 119 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; 120 121 static void print(OpAsmPrinter &p, ExecuteOp op) { 122 p << op.getOperationName(); 123 124 // [%tokens,...] 125 if (!op.dependencies().empty()) 126 p << " [" << op.dependencies() << "]"; 127 128 // (%value as %unwrapped: !async.value<!arg.type>, ...) 129 if (!op.operands().empty()) { 130 p << " ("; 131 llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { 132 p << operand << " as " << op.body().front().getArgument(n++) << ": " 133 << operand.getType(); 134 }); 135 p << ")"; 136 } 137 138 // -> (!async.value<!return.type>, ...) 139 p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1)); 140 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr}); 141 p.printRegion(op.body(), /*printEntryBlockArgs=*/false); 142 } 143 144 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { 145 MLIRContext *ctx = result.getContext(); 146 147 // Sizes of parsed variadic operands, will be updated below after parsing. 148 int32_t numDependencies = 0; 149 int32_t numOperands = 0; 150 151 auto tokenTy = TokenType::get(ctx); 152 153 // Parse dependency tokens. 154 if (succeeded(parser.parseOptionalLSquare())) { 155 SmallVector<OpAsmParser::OperandType, 4> tokenArgs; 156 if (parser.parseOperandList(tokenArgs) || 157 parser.resolveOperands(tokenArgs, tokenTy, result.operands) || 158 parser.parseRSquare()) 159 return failure(); 160 161 numDependencies = tokenArgs.size(); 162 } 163 164 // Parse async value operands (%value as %unwrapped : !async.value<!type>). 165 SmallVector<OpAsmParser::OperandType, 4> valueArgs; 166 SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; 167 SmallVector<Type, 4> valueTypes; 168 SmallVector<Type, 4> unwrappedTypes; 169 170 if (succeeded(parser.parseOptionalLParen())) { 171 auto argsLoc = parser.getCurrentLocation(); 172 173 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. 174 auto parseAsyncValueArg = [&]() -> ParseResult { 175 if (parser.parseOperand(valueArgs.emplace_back()) || 176 parser.parseKeyword("as") || 177 parser.parseOperand(unwrappedArgs.emplace_back()) || 178 parser.parseColonType(valueTypes.emplace_back())) 179 return failure(); 180 181 auto valueTy = valueTypes.back().dyn_cast<ValueType>(); 182 unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); 183 184 return success(); 185 }; 186 187 // If the next token is `)` skip async value arguments parsing. 188 if (failed(parser.parseOptionalRParen())) { 189 do { 190 if (parseAsyncValueArg()) 191 return failure(); 192 } while (succeeded(parser.parseOptionalComma())); 193 194 if (parser.parseRParen() || 195 parser.resolveOperands(valueArgs, valueTypes, argsLoc, 196 result.operands)) 197 return failure(); 198 } 199 200 numOperands = valueArgs.size(); 201 } 202 203 // Add derived `operand_segment_sizes` attribute based on parsed operands. 204 auto operandSegmentSizes = DenseIntElementsAttr::get( 205 VectorType::get({2}, parser.getBuilder().getI32Type()), 206 {numDependencies, numOperands}); 207 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); 208 209 // Parse the types of results returned from the async execute op. 210 SmallVector<Type, 4> resultTypes; 211 if (parser.parseOptionalArrowTypeList(resultTypes)) 212 return failure(); 213 214 // Async execute first result is always a completion token. 215 parser.addTypeToList(tokenTy, result.types); 216 parser.addTypesToList(resultTypes, result.types); 217 218 // Parse operation attributes. 219 NamedAttrList attrs; 220 if (parser.parseOptionalAttrDictWithKeyword(attrs)) 221 return failure(); 222 result.addAttributes(attrs); 223 224 // Parse asynchronous region. 225 Region *body = result.addRegion(); 226 if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, 227 /*argTypes=*/{unwrappedTypes}, 228 /*enableNameShadowing=*/false)) 229 return failure(); 230 231 return success(); 232 } 233 234 static LogicalResult verify(ExecuteOp op) { 235 // Unwrap async.execute value operands types. 236 auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { 237 return operand.getType().cast<ValueType>().getValueType(); 238 }); 239 240 // Verify that unwrapped argument types matches the body region arguments. 241 if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes())) 242 return op.emitOpError("the number of async body region arguments does not " 243 "match the number of execute operation arguments"); 244 245 if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(), 246 op.body().getArgumentTypes().begin())) 247 return op.emitOpError("async body region argument types do not match the " 248 "execute operation arguments types"); 249 250 return success(); 251 } 252 253 //===----------------------------------------------------------------------===// 254 /// AwaitOp 255 //===----------------------------------------------------------------------===// 256 257 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, 258 ArrayRef<NamedAttribute> attrs) { 259 result.addOperands({operand}); 260 result.attributes.append(attrs.begin(), attrs.end()); 261 262 // Add unwrapped async.value type to the returned values types. 263 if (auto valueType = operand.getType().dyn_cast<ValueType>()) 264 result.addTypes(valueType.getValueType()); 265 } 266 267 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, 268 Type &resultType) { 269 if (parser.parseType(operandType)) 270 return failure(); 271 272 // Add unwrapped async.value type to the returned values types. 273 if (auto valueType = operandType.dyn_cast<ValueType>()) 274 resultType = valueType.getValueType(); 275 276 return success(); 277 } 278 279 static void printAwaitResultType(OpAsmPrinter &p, Type operandType, 280 Type resultType) { 281 p << operandType; 282 } 283 284 static LogicalResult verify(AwaitOp op) { 285 Type argType = op.operand().getType(); 286 287 // Awaiting on a token does not have any results. 288 if (argType.isa<TokenType>() && !op.getResultTypes().empty()) 289 return op.emitOpError("awaiting on a token must have empty result"); 290 291 // Awaiting on a value unwraps the async value type. 292 if (auto value = argType.dyn_cast<ValueType>()) { 293 if (*op.getResultType() != value.getValueType()) 294 return op.emitOpError() 295 << "result type " << *op.getResultType() 296 << " does not match async value type " << value.getValueType(); 297 } 298 299 return success(); 300 } 301 302 #define GET_OP_CLASSES 303 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" 304