1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 void AsyncDialect::initialize() {
18   addOperations<
19 #define GET_OP_LIST
20 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
21       >();
22   addTypes<TokenType>();
23   addTypes<ValueType>();
24 }
25 
26 /// Parse a type registered to this dialect.
27 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
28   StringRef keyword;
29   if (parser.parseKeyword(&keyword))
30     return Type();
31 
32   if (keyword == "token")
33     return TokenType::get(getContext());
34 
35   if (keyword == "value") {
36     Type ty;
37     if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
38       parser.emitError(parser.getNameLoc(), "failed to parse async value type");
39       return Type();
40     }
41     return ValueType::get(ty);
42   }
43 
44   parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
45   return Type();
46 }
47 
48 /// Print a type registered to this dialect.
49 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
50   TypeSwitch<Type>(type)
51       .Case<TokenType>([&](TokenType) { os << "token"; })
52       .Case<ValueType>([&](ValueType valueTy) {
53         os << "value<";
54         os.printType(valueTy.getValueType());
55         os << '>';
56       })
57       .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
58 }
59 
60 //===----------------------------------------------------------------------===//
// ValueType
62 //===----------------------------------------------------------------------===//
63 
64 namespace mlir {
65 namespace async {
66 namespace detail {
67 
68 // Storage for `async.value<T>` type, the only member is the wrapped type.
69 struct ValueTypeStorage : public TypeStorage {
70   ValueTypeStorage(Type valueType) : valueType(valueType) {}
71 
72   /// The hash key used for uniquing.
73   using KeyTy = Type;
74   bool operator==(const KeyTy &key) const { return key == valueType; }
75 
76   /// Construction.
77   static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
78                                      Type valueType) {
79     return new (allocator.allocate<ValueTypeStorage>())
80         ValueTypeStorage(valueType);
81   }
82 
83   Type valueType;
84 };
85 
86 } // namespace detail
87 } // namespace async
88 } // namespace mlir
89 
90 ValueType ValueType::get(Type valueType) {
91   return Base::get(valueType.getContext(), valueType);
92 }
93 
94 Type ValueType::getValueType() { return getImpl()->valueType; }
95 
96 //===----------------------------------------------------------------------===//
97 // YieldOp
98 //===----------------------------------------------------------------------===//
99 
100 static LogicalResult verify(YieldOp op) {
101   // Get the underlying value types from async values returned from the
102   // parent `async.execute` operation.
103   auto executeOp = op.getParentOfType<ExecuteOp>();
104   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
105     return result.getType().cast<ValueType>().getValueType();
106   });
107 
108   if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
109     return op.emitOpError("Operand types do not match the types returned from "
110                           "the parent ExecuteOp");
111 
112   return success();
113 }
114 
115 //===----------------------------------------------------------------------===//
// ExecuteOp
117 //===----------------------------------------------------------------------===//
118 
// Name of the attribute recording how many operands belong to each variadic
// operand group of `async.execute`. It is derived while parsing and elided
// when printing the custom assembly form.
static constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
120 
121 static void print(OpAsmPrinter &p, ExecuteOp op) {
122   p << op.getOperationName();
123 
124   // [%tokens,...]
125   if (!op.dependencies().empty())
126     p << " [" << op.dependencies() << "]";
127 
128   // (%value as %unwrapped: !async.value<!arg.type>, ...)
129   if (!op.operands().empty()) {
130     p << " (";
131     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
132       p << operand << " as " << op.body().front().getArgument(n++) << ": "
133         << operand.getType();
134     });
135     p << ")";
136   }
137 
138   // -> (!async.value<!return.type>, ...)
139   p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
140   p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
141   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
142 }
143 
144 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
145   MLIRContext *ctx = result.getContext();
146 
147   // Sizes of parsed variadic operands, will be updated below after parsing.
148   int32_t numDependencies = 0;
149   int32_t numOperands = 0;
150 
151   auto tokenTy = TokenType::get(ctx);
152 
153   // Parse dependency tokens.
154   if (succeeded(parser.parseOptionalLSquare())) {
155     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
156     if (parser.parseOperandList(tokenArgs) ||
157         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
158         parser.parseRSquare())
159       return failure();
160 
161     numDependencies = tokenArgs.size();
162   }
163 
164   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
165   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
166   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
167   SmallVector<Type, 4> valueTypes;
168   SmallVector<Type, 4> unwrappedTypes;
169 
170   if (succeeded(parser.parseOptionalLParen())) {
171     auto argsLoc = parser.getCurrentLocation();
172 
173     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
174     auto parseAsyncValueArg = [&]() -> ParseResult {
175       if (parser.parseOperand(valueArgs.emplace_back()) ||
176           parser.parseKeyword("as") ||
177           parser.parseOperand(unwrappedArgs.emplace_back()) ||
178           parser.parseColonType(valueTypes.emplace_back()))
179         return failure();
180 
181       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
182       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
183 
184       return success();
185     };
186 
187     // If the next token is `)` skip async value arguments parsing.
188     if (failed(parser.parseOptionalRParen())) {
189       do {
190         if (parseAsyncValueArg())
191           return failure();
192       } while (succeeded(parser.parseOptionalComma()));
193 
194       if (parser.parseRParen() ||
195           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
196                                  result.operands))
197         return failure();
198     }
199 
200     numOperands = valueArgs.size();
201   }
202 
203   // Add derived `operand_segment_sizes` attribute based on parsed operands.
204   auto operandSegmentSizes = DenseIntElementsAttr::get(
205       VectorType::get({2}, parser.getBuilder().getI32Type()),
206       {numDependencies, numOperands});
207   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
208 
209   // Parse the types of results returned from the async execute op.
210   SmallVector<Type, 4> resultTypes;
211   if (parser.parseOptionalArrowTypeList(resultTypes))
212     return failure();
213 
214   // Async execute first result is always a completion token.
215   parser.addTypeToList(tokenTy, result.types);
216   parser.addTypesToList(resultTypes, result.types);
217 
218   // Parse operation attributes.
219   NamedAttrList attrs;
220   if (parser.parseOptionalAttrDictWithKeyword(attrs))
221     return failure();
222   result.addAttributes(attrs);
223 
224   // Parse asynchronous region.
225   Region *body = result.addRegion();
226   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
227                          /*argTypes=*/{unwrappedTypes},
228                          /*enableNameShadowing=*/false))
229     return failure();
230 
231   return success();
232 }
233 
234 static LogicalResult verify(ExecuteOp op) {
235   // Unwrap async.execute value operands types.
236   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
237     return operand.getType().cast<ValueType>().getValueType();
238   });
239 
240   // Verify that unwrapped argument types matches the body region arguments.
241   if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes()))
242     return op.emitOpError("the number of async body region arguments does not "
243                           "match the number of execute operation arguments");
244 
245   if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(),
246                   op.body().getArgumentTypes().begin()))
247     return op.emitOpError("async body region argument types do not match the "
248                           "execute operation arguments types");
249 
250   return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
// AwaitOp
255 //===----------------------------------------------------------------------===//
256 
257 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
258                     ArrayRef<NamedAttribute> attrs) {
259   result.addOperands({operand});
260   result.attributes.append(attrs.begin(), attrs.end());
261 
262   // Add unwrapped async.value type to the returned values types.
263   if (auto valueType = operand.getType().dyn_cast<ValueType>())
264     result.addTypes(valueType.getValueType());
265 }
266 
267 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
268                                         Type &resultType) {
269   if (parser.parseType(operandType))
270     return failure();
271 
272   // Add unwrapped async.value type to the returned values types.
273   if (auto valueType = operandType.dyn_cast<ValueType>())
274     resultType = valueType.getValueType();
275 
276   return success();
277 }
278 
279 static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
280                                  Type resultType) {
281   p << operandType;
282 }
283 
284 static LogicalResult verify(AwaitOp op) {
285   Type argType = op.operand().getType();
286 
287   // Awaiting on a token does not have any results.
288   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
289     return op.emitOpError("awaiting on a token must have empty result");
290 
291   // Awaiting on a value unwraps the async value type.
292   if (auto value = argType.dyn_cast<ValueType>()) {
293     if (*op.getResultType() != value.getValueType())
294       return op.emitOpError()
295              << "result type " << *op.getResultType()
296              << " does not match async value type " << value.getValueType();
297   }
298 
299   return success();
300 }
301 
302 #define GET_OP_CLASSES
303 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
304