1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// Out-of-line definition of the static constexpr data member declared in
// AsyncDialect. NOTE(review): presumably kept so the member can be ODR-used in
// pre-C++17 builds, where such members need a namespace-scope definition.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20 
/// Registers the dialect's operations and types. The lists come from the
/// TableGen-generated `GET_OP_LIST` (operations) and `GET_TYPEDEF_LIST`
/// (types) sections of the included `.cpp.inc` files.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
36 static LogicalResult verify(YieldOp op) {
37   // Get the underlying value types from async values returned from the
38   // parent `async.execute` operation.
39   auto executeOp = op->getParentOfType<ExecuteOp>();
40   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41     return result.getType().cast<ValueType>().getValueType();
42   });
43 
44   if (op.getOperandTypes() != types)
45     return op.emitOpError("operand types do not match the types returned from "
46                           "the parent ExecuteOp");
47 
48   return success();
49 }
50 
51 //===----------------------------------------------------------------------===//
// ExecuteOp
53 //===----------------------------------------------------------------------===//
54 
/// Name of the attribute that holds the sizes of the two variadic operand
/// groups of `async.execute`: dependency tokens first, then async value
/// operands. The builder and the parser add it; the custom printer elides it.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
56 
57 void ExecuteOp::getNumRegionInvocations(
58     ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
59   (void)operands;
60   assert(countPerRegion.empty());
61   countPerRegion.push_back(1);
62 }
63 
64 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
65                                     ArrayRef<Attribute> operands,
66                                     SmallVectorImpl<RegionSuccessor> &regions) {
67   // The `body` region branch back to the parent operation.
68   if (index.hasValue()) {
69     assert(*index == 0);
70     regions.push_back(RegionSuccessor(getResults()));
71     return;
72   }
73 
74   // Otherwise the successor is the body region.
75   regions.push_back(RegionSuccessor(&body()));
76 }
77 
78 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
79                       TypeRange resultTypes, ValueRange dependencies,
80                       ValueRange operands, BodyBuilderFn bodyBuilder) {
81 
82   result.addOperands(dependencies);
83   result.addOperands(operands);
84 
85   // Add derived `operand_segment_sizes` attribute based on parsed operands.
86   int32_t numDependencies = dependencies.size();
87   int32_t numOperands = operands.size();
88   auto operandSegmentSizes = DenseIntElementsAttr::get(
89       VectorType::get({2}, builder.getIntegerType(32)),
90       {numDependencies, numOperands});
91   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
92 
93   // First result is always a token, and then `resultTypes` wrapped into
94   // `async.value`.
95   result.addTypes({TokenType::get(result.getContext())});
96   for (Type type : resultTypes)
97     result.addTypes(ValueType::get(type));
98 
99   // Add a body region with block arguments as unwrapped async value operands.
100   Region *bodyRegion = result.addRegion();
101   bodyRegion->push_back(new Block);
102   Block &bodyBlock = bodyRegion->front();
103   for (Value operand : operands) {
104     auto valueType = operand.getType().dyn_cast<ValueType>();
105     bodyBlock.addArgument(valueType ? valueType.getValueType()
106                                     : operand.getType());
107   }
108 
109   // Create the default terminator if the builder is not provided and if the
110   // expected result is empty. Otherwise, leave this to the caller
111   // because we don't know which values to return from the execute op.
112   if (resultTypes.empty() && !bodyBuilder) {
113     OpBuilder::InsertionGuard guard(builder);
114     builder.setInsertionPointToStart(&bodyBlock);
115     builder.create<async::YieldOp>(result.location, ValueRange());
116   } else if (bodyBuilder) {
117     OpBuilder::InsertionGuard guard(builder);
118     builder.setInsertionPointToStart(&bodyBlock);
119     bodyBuilder(builder, result.location, bodyBlock.getArguments());
120   }
121 }
122 
123 static void print(OpAsmPrinter &p, ExecuteOp op) {
124   p << op.getOperationName();
125 
126   // [%tokens,...]
127   if (!op.dependencies().empty())
128     p << " [" << op.dependencies() << "]";
129 
130   // (%value as %unwrapped: !async.value<!arg.type>, ...)
131   if (!op.operands().empty()) {
132     p << " (";
133     Block *entry = op.body().empty() ? nullptr : &op.body().front();
134     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
135       Value argument = entry ? entry->getArgument(n++) : Value();
136       p << operand << " as " << argument << ": " << operand.getType();
137     });
138     p << ")";
139   }
140 
141   // -> (!async.value<!return.type>, ...)
142   p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
143   p.printOptionalAttrDictWithKeyword(op->getAttrs(),
144                                      {kOperandSegmentSizesAttr});
145   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
146 }
147 
148 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
149   MLIRContext *ctx = result.getContext();
150 
151   // Sizes of parsed variadic operands, will be updated below after parsing.
152   int32_t numDependencies = 0;
153   int32_t numOperands = 0;
154 
155   auto tokenTy = TokenType::get(ctx);
156 
157   // Parse dependency tokens.
158   if (succeeded(parser.parseOptionalLSquare())) {
159     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
160     if (parser.parseOperandList(tokenArgs) ||
161         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
162         parser.parseRSquare())
163       return failure();
164 
165     numDependencies = tokenArgs.size();
166   }
167 
168   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
169   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
170   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
171   SmallVector<Type, 4> valueTypes;
172   SmallVector<Type, 4> unwrappedTypes;
173 
174   if (succeeded(parser.parseOptionalLParen())) {
175     auto argsLoc = parser.getCurrentLocation();
176 
177     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
178     auto parseAsyncValueArg = [&]() -> ParseResult {
179       if (parser.parseOperand(valueArgs.emplace_back()) ||
180           parser.parseKeyword("as") ||
181           parser.parseOperand(unwrappedArgs.emplace_back()) ||
182           parser.parseColonType(valueTypes.emplace_back()))
183         return failure();
184 
185       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
186       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
187 
188       return success();
189     };
190 
191     // If the next token is `)` skip async value arguments parsing.
192     if (failed(parser.parseOptionalRParen())) {
193       do {
194         if (parseAsyncValueArg())
195           return failure();
196       } while (succeeded(parser.parseOptionalComma()));
197 
198       if (parser.parseRParen() ||
199           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
200                                  result.operands))
201         return failure();
202     }
203 
204     numOperands = valueArgs.size();
205   }
206 
207   // Add derived `operand_segment_sizes` attribute based on parsed operands.
208   auto operandSegmentSizes = DenseIntElementsAttr::get(
209       VectorType::get({2}, parser.getBuilder().getI32Type()),
210       {numDependencies, numOperands});
211   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
212 
213   // Parse the types of results returned from the async execute op.
214   SmallVector<Type, 4> resultTypes;
215   if (parser.parseOptionalArrowTypeList(resultTypes))
216     return failure();
217 
218   // Async execute first result is always a completion token.
219   parser.addTypeToList(tokenTy, result.types);
220   parser.addTypesToList(resultTypes, result.types);
221 
222   // Parse operation attributes.
223   NamedAttrList attrs;
224   if (parser.parseOptionalAttrDictWithKeyword(attrs))
225     return failure();
226   result.addAttributes(attrs);
227 
228   // Parse asynchronous region.
229   Region *body = result.addRegion();
230   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
231                          /*argTypes=*/{unwrappedTypes},
232                          /*enableNameShadowing=*/false))
233     return failure();
234 
235   return success();
236 }
237 
238 static LogicalResult verify(ExecuteOp op) {
239   // Unwrap async.execute value operands types.
240   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
241     return operand.getType().cast<ValueType>().getValueType();
242   });
243 
244   // Verify that unwrapped argument types matches the body region arguments.
245   if (op.body().getArgumentTypes() != unwrappedTypes)
246     return op.emitOpError("async body region argument types do not match the "
247                           "execute operation arguments types");
248 
249   return success();
250 }
251 
252 //===----------------------------------------------------------------------===//
// CreateGroupOp
254 //===----------------------------------------------------------------------===//
255 
256 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
257                                           PatternRewriter &rewriter) {
258   // Find all `await_all` users of the group.
259   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
260 
261   auto isAwaitAll = [&](Operation *op) -> bool {
262     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
263       awaitAllUsers.push_back(awaitAll);
264       return true;
265     }
266     return false;
267   };
268 
269   // Check if all users of the group are `await_all` operations.
270   if (!llvm::all_of(op->getUsers(), isAwaitAll))
271     return failure();
272 
273   // If group is only awaited without adding anything to it, we can safely erase
274   // the create operation and all users.
275   for (AwaitAllOp awaitAll : awaitAllUsers)
276     rewriter.eraseOp(awaitAll);
277   rewriter.eraseOp(op);
278 
279   return success();
280 }
281 
282 //===----------------------------------------------------------------------===//
// AwaitOp
284 //===----------------------------------------------------------------------===//
285 
286 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
287                     ArrayRef<NamedAttribute> attrs) {
288   result.addOperands({operand});
289   result.attributes.append(attrs.begin(), attrs.end());
290 
291   // Add unwrapped async.value type to the returned values types.
292   if (auto valueType = operand.getType().dyn_cast<ValueType>())
293     result.addTypes(valueType.getValueType());
294 }
295 
296 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
297                                         Type &resultType) {
298   if (parser.parseType(operandType))
299     return failure();
300 
301   // Add unwrapped async.value type to the returned values types.
302   if (auto valueType = operandType.dyn_cast<ValueType>())
303     resultType = valueType.getValueType();
304 
305   return success();
306 }
307 
308 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
309                                  Type operandType, Type resultType) {
310   p << operandType;
311 }
312 
313 static LogicalResult verify(AwaitOp op) {
314   Type argType = op.operand().getType();
315 
316   // Awaiting on a token does not have any results.
317   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
318     return op.emitOpError("awaiting on a token must have empty result");
319 
320   // Awaiting on a value unwraps the async value type.
321   if (auto value = argType.dyn_cast<ValueType>()) {
322     if (*op.getResultType() != value.getValueType())
323       return op.emitOpError()
324              << "result type " << *op.getResultType()
325              << " does not match async value type " << value.getValueType();
326   }
327 
328   return success();
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // TableGen'd op method definitions
333 //===----------------------------------------------------------------------===//
334 
335 #define GET_OP_CLASSES
336 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
337 
338 //===----------------------------------------------------------------------===//
339 // TableGen'd type method definitions
340 //===----------------------------------------------------------------------===//
341 
342 #define GET_TYPEDEF_CLASSES
343 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
344 
345 void ValueType::print(DialectAsmPrinter &printer) const {
346   printer << getMnemonic();
347   printer << "<";
348   printer.printType(getValueType());
349   printer << '>';
350 }
351 
352 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
353   Type ty;
354   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
355     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
356     return Type();
357   }
358   return ValueType::get(ty);
359 }
360 
361 /// Print a type registered to this dialect.
362 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
363   if (failed(generatedTypePrinter(type, os)))
364     llvm_unreachable("unexpected 'async' type kind");
365 }
366 
367 /// Parse a type registered to this dialect.
368 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
369   StringRef typeTag;
370   if (parser.parseKeyword(&typeTag))
371     return Type();
372   Type genType;
373   auto parseResult = generatedTypeParser(parser.getBuilder().getContext(),
374                                          parser, typeTag, genType);
375   if (parseResult.hasValue())
376     return genType;
377   parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
378   return {};
379 }
380