1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
/// Registers all operations and types of the 'async' dialect with the
/// context. The operation and type lists are the TableGen'd ones pulled in
/// from the generated `.inc` files below.
void AsyncDialect::initialize() {
  // Operation classes listed by the generated GET_OP_LIST section.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  // Type classes listed by the generated GET_TYPEDEF_LIST section.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
27 
28 //===----------------------------------------------------------------------===//
29 // YieldOp
30 //===----------------------------------------------------------------------===//
31 
32 static LogicalResult verify(YieldOp op) {
33   // Get the underlying value types from async values returned from the
34   // parent `async.execute` operation.
35   auto executeOp = op->getParentOfType<ExecuteOp>();
36   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
37     return result.getType().cast<ValueType>().getValueType();
38   });
39 
40   if (op.getOperandTypes() != types)
41     return op.emitOpError("operand types do not match the types returned from "
42                           "the parent ExecuteOp");
43 
44   return success();
45 }
46 
47 //===----------------------------------------------------------------------===//
// ExecuteOp
49 //===----------------------------------------------------------------------===//
50 
/// Name of the attribute recording the sizes of ExecuteOp's two variadic
/// operand groups: the number of dependency tokens followed by the number of
/// async value operands. It is attached by `ExecuteOp::build` and
/// `parseExecuteOp`, and elided from the attribute dictionary when printing.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
52 
53 void ExecuteOp::getNumRegionInvocations(
54     ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
55   (void)operands;
56   assert(countPerRegion.empty());
57   countPerRegion.push_back(1);
58 }
59 
60 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
61                                     ArrayRef<Attribute> operands,
62                                     SmallVectorImpl<RegionSuccessor> &regions) {
63   // The `body` region branch back to the parent operation.
64   if (index.hasValue()) {
65     assert(*index == 0);
66     regions.push_back(RegionSuccessor(getResults()));
67     return;
68   }
69 
70   // Otherwise the successor is the body region.
71   regions.push_back(RegionSuccessor(&body()));
72 }
73 
74 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
75                       TypeRange resultTypes, ValueRange dependencies,
76                       ValueRange operands, BodyBuilderFn bodyBuilder) {
77 
78   result.addOperands(dependencies);
79   result.addOperands(operands);
80 
81   // Add derived `operand_segment_sizes` attribute based on parsed operands.
82   int32_t numDependencies = dependencies.size();
83   int32_t numOperands = operands.size();
84   auto operandSegmentSizes = DenseIntElementsAttr::get(
85       VectorType::get({2}, builder.getIntegerType(32)),
86       {numDependencies, numOperands});
87   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
88 
89   // First result is always a token, and then `resultTypes` wrapped into
90   // `async.value`.
91   result.addTypes({TokenType::get(result.getContext())});
92   for (Type type : resultTypes)
93     result.addTypes(ValueType::get(type));
94 
95   // Add a body region with block arguments as unwrapped async value operands.
96   Region *bodyRegion = result.addRegion();
97   bodyRegion->push_back(new Block);
98   Block &bodyBlock = bodyRegion->front();
99   for (Value operand : operands) {
100     auto valueType = operand.getType().dyn_cast<ValueType>();
101     bodyBlock.addArgument(valueType ? valueType.getValueType()
102                                     : operand.getType());
103   }
104 
105   // Create the default terminator if the builder is not provided and if the
106   // expected result is empty. Otherwise, leave this to the caller
107   // because we don't know which values to return from the execute op.
108   if (resultTypes.empty() && !bodyBuilder) {
109     OpBuilder::InsertionGuard guard(builder);
110     builder.setInsertionPointToStart(&bodyBlock);
111     builder.create<async::YieldOp>(result.location, ValueRange());
112   } else if (bodyBuilder) {
113     OpBuilder::InsertionGuard guard(builder);
114     builder.setInsertionPointToStart(&bodyBlock);
115     bodyBuilder(builder, result.location, bodyBlock.getArguments());
116   }
117 }
118 
119 static void print(OpAsmPrinter &p, ExecuteOp op) {
120   p << op.getOperationName();
121 
122   // [%tokens,...]
123   if (!op.dependencies().empty())
124     p << " [" << op.dependencies() << "]";
125 
126   // (%value as %unwrapped: !async.value<!arg.type>, ...)
127   if (!op.operands().empty()) {
128     p << " (";
129     Block *entry = op.body().empty() ? nullptr : &op.body().front();
130     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
131       Value argument = entry ? entry->getArgument(n++) : Value();
132       p << operand << " as " << argument << ": " << operand.getType();
133     });
134     p << ")";
135   }
136 
137   // -> (!async.value<!return.type>, ...)
138   p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
139   p.printOptionalAttrDictWithKeyword(op->getAttrs(),
140                                      {kOperandSegmentSizesAttr});
141   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
142 }
143 
144 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
145   MLIRContext *ctx = result.getContext();
146 
147   // Sizes of parsed variadic operands, will be updated below after parsing.
148   int32_t numDependencies = 0;
149   int32_t numOperands = 0;
150 
151   auto tokenTy = TokenType::get(ctx);
152 
153   // Parse dependency tokens.
154   if (succeeded(parser.parseOptionalLSquare())) {
155     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
156     if (parser.parseOperandList(tokenArgs) ||
157         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
158         parser.parseRSquare())
159       return failure();
160 
161     numDependencies = tokenArgs.size();
162   }
163 
164   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
165   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
166   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
167   SmallVector<Type, 4> valueTypes;
168   SmallVector<Type, 4> unwrappedTypes;
169 
170   if (succeeded(parser.parseOptionalLParen())) {
171     auto argsLoc = parser.getCurrentLocation();
172 
173     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
174     auto parseAsyncValueArg = [&]() -> ParseResult {
175       if (parser.parseOperand(valueArgs.emplace_back()) ||
176           parser.parseKeyword("as") ||
177           parser.parseOperand(unwrappedArgs.emplace_back()) ||
178           parser.parseColonType(valueTypes.emplace_back()))
179         return failure();
180 
181       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
182       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
183 
184       return success();
185     };
186 
187     // If the next token is `)` skip async value arguments parsing.
188     if (failed(parser.parseOptionalRParen())) {
189       do {
190         if (parseAsyncValueArg())
191           return failure();
192       } while (succeeded(parser.parseOptionalComma()));
193 
194       if (parser.parseRParen() ||
195           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
196                                  result.operands))
197         return failure();
198     }
199 
200     numOperands = valueArgs.size();
201   }
202 
203   // Add derived `operand_segment_sizes` attribute based on parsed operands.
204   auto operandSegmentSizes = DenseIntElementsAttr::get(
205       VectorType::get({2}, parser.getBuilder().getI32Type()),
206       {numDependencies, numOperands});
207   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
208 
209   // Parse the types of results returned from the async execute op.
210   SmallVector<Type, 4> resultTypes;
211   if (parser.parseOptionalArrowTypeList(resultTypes))
212     return failure();
213 
214   // Async execute first result is always a completion token.
215   parser.addTypeToList(tokenTy, result.types);
216   parser.addTypesToList(resultTypes, result.types);
217 
218   // Parse operation attributes.
219   NamedAttrList attrs;
220   if (parser.parseOptionalAttrDictWithKeyword(attrs))
221     return failure();
222   result.addAttributes(attrs);
223 
224   // Parse asynchronous region.
225   Region *body = result.addRegion();
226   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
227                          /*argTypes=*/{unwrappedTypes},
228                          /*enableNameShadowing=*/false))
229     return failure();
230 
231   return success();
232 }
233 
234 static LogicalResult verify(ExecuteOp op) {
235   // Unwrap async.execute value operands types.
236   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
237     return operand.getType().cast<ValueType>().getValueType();
238   });
239 
240   // Verify that unwrapped argument types matches the body region arguments.
241   if (op.body().getArgumentTypes() != unwrappedTypes)
242     return op.emitOpError("async body region argument types do not match the "
243                           "execute operation arguments types");
244 
245   return success();
246 }
247 
248 //===----------------------------------------------------------------------===//
// AwaitOp
250 //===----------------------------------------------------------------------===//
251 
252 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
253                     ArrayRef<NamedAttribute> attrs) {
254   result.addOperands({operand});
255   result.attributes.append(attrs.begin(), attrs.end());
256 
257   // Add unwrapped async.value type to the returned values types.
258   if (auto valueType = operand.getType().dyn_cast<ValueType>())
259     result.addTypes(valueType.getValueType());
260 }
261 
262 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
263                                         Type &resultType) {
264   if (parser.parseType(operandType))
265     return failure();
266 
267   // Add unwrapped async.value type to the returned values types.
268   if (auto valueType = operandType.dyn_cast<ValueType>())
269     resultType = valueType.getValueType();
270 
271   return success();
272 }
273 
274 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
275                                  Type operandType, Type resultType) {
276   p << operandType;
277 }
278 
279 static LogicalResult verify(AwaitOp op) {
280   Type argType = op.operand().getType();
281 
282   // Awaiting on a token does not have any results.
283   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
284     return op.emitOpError("awaiting on a token must have empty result");
285 
286   // Awaiting on a value unwraps the async value type.
287   if (auto value = argType.dyn_cast<ValueType>()) {
288     if (*op.getResultType() != value.getValueType())
289       return op.emitOpError()
290              << "result type " << *op.getResultType()
291              << " does not match async value type " << value.getValueType();
292   }
293 
294   return success();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // TableGen'd op method definitions
299 //===----------------------------------------------------------------------===//
300 
301 #define GET_OP_CLASSES
302 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
303 
304 //===----------------------------------------------------------------------===//
305 // TableGen'd type method definitions
306 //===----------------------------------------------------------------------===//
307 
308 #define GET_TYPEDEF_CLASSES
309 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
310 
311 void ValueType::print(DialectAsmPrinter &printer) const {
312   printer << getMnemonic();
313   printer << "<";
314   printer.printType(getValueType());
315   printer << '>';
316 }
317 
318 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
319   Type ty;
320   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
321     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
322     return Type();
323   }
324   return ValueType::get(ty);
325 }
326 
327 /// Print a type registered to this dialect.
328 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
329   if (failed(generatedTypePrinter(type, os)))
330     llvm_unreachable("unexpected 'async' type kind");
331 }
332 
333 /// Parse a type registered to this dialect.
334 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
335   StringRef typeTag;
336   if (parser.parseKeyword(&typeTag))
337     return Type();
338   Type genType;
339   auto parseResult = generatedTypeParser(parser.getBuilder().getContext(),
340                                          parser, typeTag, genType);
341   if (parseResult.hasValue())
342     return genType;
343   parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
344   return {};
345 }
346