1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// Out-of-line definition of the static constexpr member declared (with its
// initializer) on AsyncDialect; pre-C++17 this definition is required if the
// member is ODR-used.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20 
/// Registers the dialect's operations and types. Both lists are produced by
/// TableGen: the GET_OP_LIST / GET_TYPEDEF_LIST macros make the included
/// .inc files expand to comma-separated class lists used as template
/// arguments.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
36 static LogicalResult verify(YieldOp op) {
37   // Get the underlying value types from async values returned from the
38   // parent `async.execute` operation.
39   auto executeOp = op->getParentOfType<ExecuteOp>();
40   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41     return result.getType().cast<ValueType>().getValueType();
42   });
43 
44   if (op.getOperandTypes() != types)
45     return op.emitOpError("operand types do not match the types returned from "
46                           "the parent ExecuteOp");
47 
48   return success();
49 }
50 
51 MutableOperandRange
52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
53   assert(!index.hasValue());
54   return operandsMutable();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 /// ExecuteOp
59 //===----------------------------------------------------------------------===//
60 
// Name of the attribute recording how many operands belong to the
// `dependencies` and `operands` segments. It is derived in build()/the parser
// and elided by the custom printer.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
62 
63 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
64   assert(index == 0 && "invalid region index");
65   return operands();
66 }
67 
68 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
69                                     ArrayRef<Attribute>,
70                                     SmallVectorImpl<RegionSuccessor> &regions) {
71   // The `body` region branch back to the parent operation.
72   if (index.hasValue()) {
73     assert(*index == 0 && "invalid region index");
74     regions.push_back(RegionSuccessor(results()));
75     return;
76   }
77 
78   // Otherwise the successor is the body region.
79   regions.push_back(RegionSuccessor(&body(), body().getArguments()));
80 }
81 
82 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
83                       TypeRange resultTypes, ValueRange dependencies,
84                       ValueRange operands, BodyBuilderFn bodyBuilder) {
85 
86   result.addOperands(dependencies);
87   result.addOperands(operands);
88 
89   // Add derived `operand_segment_sizes` attribute based on parsed operands.
90   int32_t numDependencies = dependencies.size();
91   int32_t numOperands = operands.size();
92   auto operandSegmentSizes = DenseIntElementsAttr::get(
93       VectorType::get({2}, builder.getIntegerType(32)),
94       {numDependencies, numOperands});
95   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
96 
97   // First result is always a token, and then `resultTypes` wrapped into
98   // `async.value`.
99   result.addTypes({TokenType::get(result.getContext())});
100   for (Type type : resultTypes)
101     result.addTypes(ValueType::get(type));
102 
103   // Add a body region with block arguments as unwrapped async value operands.
104   Region *bodyRegion = result.addRegion();
105   bodyRegion->push_back(new Block);
106   Block &bodyBlock = bodyRegion->front();
107   for (Value operand : operands) {
108     auto valueType = operand.getType().dyn_cast<ValueType>();
109     bodyBlock.addArgument(valueType ? valueType.getValueType()
110                                     : operand.getType());
111   }
112 
113   // Create the default terminator if the builder is not provided and if the
114   // expected result is empty. Otherwise, leave this to the caller
115   // because we don't know which values to return from the execute op.
116   if (resultTypes.empty() && !bodyBuilder) {
117     OpBuilder::InsertionGuard guard(builder);
118     builder.setInsertionPointToStart(&bodyBlock);
119     builder.create<async::YieldOp>(result.location, ValueRange());
120   } else if (bodyBuilder) {
121     OpBuilder::InsertionGuard guard(builder);
122     builder.setInsertionPointToStart(&bodyBlock);
123     bodyBuilder(builder, result.location, bodyBlock.getArguments());
124   }
125 }
126 
127 static void print(OpAsmPrinter &p, ExecuteOp op) {
128   // [%tokens,...]
129   if (!op.dependencies().empty())
130     p << " [" << op.dependencies() << "]";
131 
132   // (%value as %unwrapped: !async.value<!arg.type>, ...)
133   if (!op.operands().empty()) {
134     p << " (";
135     Block *entry = op.body().empty() ? nullptr : &op.body().front();
136     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
137       Value argument = entry ? entry->getArgument(n++) : Value();
138       p << operand << " as " << argument << ": " << operand.getType();
139     });
140     p << ")";
141   }
142 
143   // -> (!async.value<!return.type>, ...)
144   p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
145   p.printOptionalAttrDictWithKeyword(op->getAttrs(),
146                                      {kOperandSegmentSizesAttr});
147   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
148 }
149 
150 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
151   MLIRContext *ctx = result.getContext();
152 
153   // Sizes of parsed variadic operands, will be updated below after parsing.
154   int32_t numDependencies = 0;
155 
156   auto tokenTy = TokenType::get(ctx);
157 
158   // Parse dependency tokens.
159   if (succeeded(parser.parseOptionalLSquare())) {
160     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
161     if (parser.parseOperandList(tokenArgs) ||
162         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
163         parser.parseRSquare())
164       return failure();
165 
166     numDependencies = tokenArgs.size();
167   }
168 
169   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
170   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
171   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
172   SmallVector<Type, 4> valueTypes;
173   SmallVector<Type, 4> unwrappedTypes;
174 
175   // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
176   auto parseAsyncValueArg = [&]() -> ParseResult {
177     if (parser.parseOperand(valueArgs.emplace_back()) ||
178         parser.parseKeyword("as") ||
179         parser.parseOperand(unwrappedArgs.emplace_back()) ||
180         parser.parseColonType(valueTypes.emplace_back()))
181       return failure();
182 
183     auto valueTy = valueTypes.back().dyn_cast<ValueType>();
184     unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
185 
186     return success();
187   };
188 
189   auto argsLoc = parser.getCurrentLocation();
190   if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
191                                      parseAsyncValueArg) ||
192       parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
193     return failure();
194 
195   int32_t numOperands = valueArgs.size();
196 
197   // Add derived `operand_segment_sizes` attribute based on parsed operands.
198   auto operandSegmentSizes = DenseIntElementsAttr::get(
199       VectorType::get({2}, parser.getBuilder().getI32Type()),
200       {numDependencies, numOperands});
201   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
202 
203   // Parse the types of results returned from the async execute op.
204   SmallVector<Type, 4> resultTypes;
205   if (parser.parseOptionalArrowTypeList(resultTypes))
206     return failure();
207 
208   // Async execute first result is always a completion token.
209   parser.addTypeToList(tokenTy, result.types);
210   parser.addTypesToList(resultTypes, result.types);
211 
212   // Parse operation attributes.
213   NamedAttrList attrs;
214   if (parser.parseOptionalAttrDictWithKeyword(attrs))
215     return failure();
216   result.addAttributes(attrs);
217 
218   // Parse asynchronous region.
219   Region *body = result.addRegion();
220   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
221                          /*argTypes=*/{unwrappedTypes},
222                          /*enableNameShadowing=*/false))
223     return failure();
224 
225   return success();
226 }
227 
228 static LogicalResult verify(ExecuteOp op) {
229   // Unwrap async.execute value operands types.
230   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
231     return operand.getType().cast<ValueType>().getValueType();
232   });
233 
234   // Verify that unwrapped argument types matches the body region arguments.
235   if (op.body().getArgumentTypes() != unwrappedTypes)
236     return op.emitOpError("async body region argument types do not match the "
237                           "execute operation arguments types");
238 
239   return success();
240 }
241 
242 //===----------------------------------------------------------------------===//
243 /// CreateGroupOp
244 //===----------------------------------------------------------------------===//
245 
246 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
247                                           PatternRewriter &rewriter) {
248   // Find all `await_all` users of the group.
249   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
250 
251   auto isAwaitAll = [&](Operation *op) -> bool {
252     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
253       awaitAllUsers.push_back(awaitAll);
254       return true;
255     }
256     return false;
257   };
258 
259   // Check if all users of the group are `await_all` operations.
260   if (!llvm::all_of(op->getUsers(), isAwaitAll))
261     return failure();
262 
263   // If group is only awaited without adding anything to it, we can safely erase
264   // the create operation and all users.
265   for (AwaitAllOp awaitAll : awaitAllUsers)
266     rewriter.eraseOp(awaitAll);
267   rewriter.eraseOp(op);
268 
269   return success();
270 }
271 
272 //===----------------------------------------------------------------------===//
273 /// AwaitOp
274 //===----------------------------------------------------------------------===//
275 
276 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
277                     ArrayRef<NamedAttribute> attrs) {
278   result.addOperands({operand});
279   result.attributes.append(attrs.begin(), attrs.end());
280 
281   // Add unwrapped async.value type to the returned values types.
282   if (auto valueType = operand.getType().dyn_cast<ValueType>())
283     result.addTypes(valueType.getValueType());
284 }
285 
286 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
287                                         Type &resultType) {
288   if (parser.parseType(operandType))
289     return failure();
290 
291   // Add unwrapped async.value type to the returned values types.
292   if (auto valueType = operandType.dyn_cast<ValueType>())
293     resultType = valueType.getValueType();
294 
295   return success();
296 }
297 
298 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
299                                  Type operandType, Type resultType) {
300   p << operandType;
301 }
302 
303 static LogicalResult verify(AwaitOp op) {
304   Type argType = op.operand().getType();
305 
306   // Awaiting on a token does not have any results.
307   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
308     return op.emitOpError("awaiting on a token must have empty result");
309 
310   // Awaiting on a value unwraps the async value type.
311   if (auto value = argType.dyn_cast<ValueType>()) {
312     if (*op.getResultType() != value.getValueType())
313       return op.emitOpError()
314              << "result type " << *op.getResultType()
315              << " does not match async value type " << value.getValueType();
316   }
317 
318   return success();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // TableGen'd op method definitions
323 //===----------------------------------------------------------------------===//
324 
325 #define GET_OP_CLASSES
326 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
327 
328 //===----------------------------------------------------------------------===//
329 // TableGen'd type method definitions
330 //===----------------------------------------------------------------------===//
331 
332 #define GET_TYPEDEF_CLASSES
333 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
334 
335 void ValueType::print(AsmPrinter &printer) const {
336   printer << "<";
337   printer.printType(getValueType());
338   printer << '>';
339 }
340 
341 Type ValueType::parse(mlir::AsmParser &parser) {
342   Type ty;
343   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
344     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
345     return Type();
346   }
347   return ValueType::get(ty);
348 }
349 
350 /// Print a type registered to this dialect.
351 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
352   if (failed(generatedTypePrinter(type, os)))
353     llvm_unreachable("unexpected 'async' type kind");
354 }
355 
356 /// Parse a type registered to this dialect.
357 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
358   StringRef typeTag;
359   if (parser.parseKeyword(&typeTag))
360     return Type();
361   Type genType;
362   auto parseResult = generatedTypeParser(parser, typeTag, genType);
363   if (parseResult.hasValue())
364     return genType;
365   parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
366   return {};
367 }
368