1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// Out-of-line definition of the constexpr static data member declared in the
// AsyncDialect class. NOTE(review): this is only needed when the member is
// ODR-used and the file is built as pre-C++17 (C++17 makes constexpr static
// members implicitly inline) — confirm before removing.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20 
/// Registers every operation and type of the async dialect. The template
/// argument lists are produced by the TableGen'd includes below:
/// GET_OP_LIST / GET_TYPEDEF_LIST make them expand to the generated class
/// names.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
36 static LogicalResult verify(YieldOp op) {
37   // Get the underlying value types from async values returned from the
38   // parent `async.execute` operation.
39   auto executeOp = op->getParentOfType<ExecuteOp>();
40   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41     return result.getType().cast<ValueType>().getValueType();
42   });
43 
44   if (op.getOperandTypes() != types)
45     return op.emitOpError("operand types do not match the types returned from "
46                           "the parent ExecuteOp");
47 
48   return success();
49 }
50 
51 MutableOperandRange
52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
53   assert(!index.hasValue());
54   return operandsMutable();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 /// ExecuteOp
59 //===----------------------------------------------------------------------===//
60 
/// Name of the derived attribute that records the size of each variadic
/// operand group of `async.execute`: the number of dependency tokens first,
/// then the number of async value operands (see ExecuteOp::build and
/// parseExecuteOp).
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
62 
63 void ExecuteOp::getNumRegionInvocations(
64     ArrayRef<Attribute>, SmallVectorImpl<int64_t> &countPerRegion) {
65   assert(countPerRegion.empty());
66   countPerRegion.push_back(1);
67 }
68 
69 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
70   assert(index == 0 && "invalid region index");
71   return operands();
72 }
73 
74 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
75                                     ArrayRef<Attribute>,
76                                     SmallVectorImpl<RegionSuccessor> &regions) {
77   // The `body` region branch back to the parent operation.
78   if (index.hasValue()) {
79     assert(*index == 0 && "invalid region index");
80     regions.push_back(RegionSuccessor(results()));
81     return;
82   }
83 
84   // Otherwise the successor is the body region.
85   regions.push_back(RegionSuccessor(&body(), body().getArguments()));
86 }
87 
88 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
89                       TypeRange resultTypes, ValueRange dependencies,
90                       ValueRange operands, BodyBuilderFn bodyBuilder) {
91 
92   result.addOperands(dependencies);
93   result.addOperands(operands);
94 
95   // Add derived `operand_segment_sizes` attribute based on parsed operands.
96   int32_t numDependencies = dependencies.size();
97   int32_t numOperands = operands.size();
98   auto operandSegmentSizes = DenseIntElementsAttr::get(
99       VectorType::get({2}, builder.getIntegerType(32)),
100       {numDependencies, numOperands});
101   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
102 
103   // First result is always a token, and then `resultTypes` wrapped into
104   // `async.value`.
105   result.addTypes({TokenType::get(result.getContext())});
106   for (Type type : resultTypes)
107     result.addTypes(ValueType::get(type));
108 
109   // Add a body region with block arguments as unwrapped async value operands.
110   Region *bodyRegion = result.addRegion();
111   bodyRegion->push_back(new Block);
112   Block &bodyBlock = bodyRegion->front();
113   for (Value operand : operands) {
114     auto valueType = operand.getType().dyn_cast<ValueType>();
115     bodyBlock.addArgument(valueType ? valueType.getValueType()
116                                     : operand.getType());
117   }
118 
119   // Create the default terminator if the builder is not provided and if the
120   // expected result is empty. Otherwise, leave this to the caller
121   // because we don't know which values to return from the execute op.
122   if (resultTypes.empty() && !bodyBuilder) {
123     OpBuilder::InsertionGuard guard(builder);
124     builder.setInsertionPointToStart(&bodyBlock);
125     builder.create<async::YieldOp>(result.location, ValueRange());
126   } else if (bodyBuilder) {
127     OpBuilder::InsertionGuard guard(builder);
128     builder.setInsertionPointToStart(&bodyBlock);
129     bodyBuilder(builder, result.location, bodyBlock.getArguments());
130   }
131 }
132 
133 static void print(OpAsmPrinter &p, ExecuteOp op) {
134   p << op.getOperationName();
135 
136   // [%tokens,...]
137   if (!op.dependencies().empty())
138     p << " [" << op.dependencies() << "]";
139 
140   // (%value as %unwrapped: !async.value<!arg.type>, ...)
141   if (!op.operands().empty()) {
142     p << " (";
143     Block *entry = op.body().empty() ? nullptr : &op.body().front();
144     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
145       Value argument = entry ? entry->getArgument(n++) : Value();
146       p << operand << " as " << argument << ": " << operand.getType();
147     });
148     p << ")";
149   }
150 
151   // -> (!async.value<!return.type>, ...)
152   p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
153   p.printOptionalAttrDictWithKeyword(op->getAttrs(),
154                                      {kOperandSegmentSizesAttr});
155   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
156 }
157 
158 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
159   MLIRContext *ctx = result.getContext();
160 
161   // Sizes of parsed variadic operands, will be updated below after parsing.
162   int32_t numDependencies = 0;
163   int32_t numOperands = 0;
164 
165   auto tokenTy = TokenType::get(ctx);
166 
167   // Parse dependency tokens.
168   if (succeeded(parser.parseOptionalLSquare())) {
169     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
170     if (parser.parseOperandList(tokenArgs) ||
171         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
172         parser.parseRSquare())
173       return failure();
174 
175     numDependencies = tokenArgs.size();
176   }
177 
178   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
179   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
180   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
181   SmallVector<Type, 4> valueTypes;
182   SmallVector<Type, 4> unwrappedTypes;
183 
184   if (succeeded(parser.parseOptionalLParen())) {
185     auto argsLoc = parser.getCurrentLocation();
186 
187     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
188     auto parseAsyncValueArg = [&]() -> ParseResult {
189       if (parser.parseOperand(valueArgs.emplace_back()) ||
190           parser.parseKeyword("as") ||
191           parser.parseOperand(unwrappedArgs.emplace_back()) ||
192           parser.parseColonType(valueTypes.emplace_back()))
193         return failure();
194 
195       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
196       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
197 
198       return success();
199     };
200 
201     // If the next token is `)` skip async value arguments parsing.
202     if (failed(parser.parseOptionalRParen())) {
203       do {
204         if (parseAsyncValueArg())
205           return failure();
206       } while (succeeded(parser.parseOptionalComma()));
207 
208       if (parser.parseRParen() ||
209           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
210                                  result.operands))
211         return failure();
212     }
213 
214     numOperands = valueArgs.size();
215   }
216 
217   // Add derived `operand_segment_sizes` attribute based on parsed operands.
218   auto operandSegmentSizes = DenseIntElementsAttr::get(
219       VectorType::get({2}, parser.getBuilder().getI32Type()),
220       {numDependencies, numOperands});
221   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
222 
223   // Parse the types of results returned from the async execute op.
224   SmallVector<Type, 4> resultTypes;
225   if (parser.parseOptionalArrowTypeList(resultTypes))
226     return failure();
227 
228   // Async execute first result is always a completion token.
229   parser.addTypeToList(tokenTy, result.types);
230   parser.addTypesToList(resultTypes, result.types);
231 
232   // Parse operation attributes.
233   NamedAttrList attrs;
234   if (parser.parseOptionalAttrDictWithKeyword(attrs))
235     return failure();
236   result.addAttributes(attrs);
237 
238   // Parse asynchronous region.
239   Region *body = result.addRegion();
240   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
241                          /*argTypes=*/{unwrappedTypes},
242                          /*enableNameShadowing=*/false))
243     return failure();
244 
245   return success();
246 }
247 
248 static LogicalResult verify(ExecuteOp op) {
249   // Unwrap async.execute value operands types.
250   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
251     return operand.getType().cast<ValueType>().getValueType();
252   });
253 
254   // Verify that unwrapped argument types matches the body region arguments.
255   if (op.body().getArgumentTypes() != unwrappedTypes)
256     return op.emitOpError("async body region argument types do not match the "
257                           "execute operation arguments types");
258 
259   return success();
260 }
261 
262 //===----------------------------------------------------------------------===//
263 /// CreateGroupOp
264 //===----------------------------------------------------------------------===//
265 
266 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
267                                           PatternRewriter &rewriter) {
268   // Find all `await_all` users of the group.
269   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
270 
271   auto isAwaitAll = [&](Operation *op) -> bool {
272     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
273       awaitAllUsers.push_back(awaitAll);
274       return true;
275     }
276     return false;
277   };
278 
279   // Check if all users of the group are `await_all` operations.
280   if (!llvm::all_of(op->getUsers(), isAwaitAll))
281     return failure();
282 
283   // If group is only awaited without adding anything to it, we can safely erase
284   // the create operation and all users.
285   for (AwaitAllOp awaitAll : awaitAllUsers)
286     rewriter.eraseOp(awaitAll);
287   rewriter.eraseOp(op);
288 
289   return success();
290 }
291 
292 //===----------------------------------------------------------------------===//
293 /// AwaitOp
294 //===----------------------------------------------------------------------===//
295 
296 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
297                     ArrayRef<NamedAttribute> attrs) {
298   result.addOperands({operand});
299   result.attributes.append(attrs.begin(), attrs.end());
300 
301   // Add unwrapped async.value type to the returned values types.
302   if (auto valueType = operand.getType().dyn_cast<ValueType>())
303     result.addTypes(valueType.getValueType());
304 }
305 
306 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
307                                         Type &resultType) {
308   if (parser.parseType(operandType))
309     return failure();
310 
311   // Add unwrapped async.value type to the returned values types.
312   if (auto valueType = operandType.dyn_cast<ValueType>())
313     resultType = valueType.getValueType();
314 
315   return success();
316 }
317 
318 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
319                                  Type operandType, Type resultType) {
320   p << operandType;
321 }
322 
323 static LogicalResult verify(AwaitOp op) {
324   Type argType = op.operand().getType();
325 
326   // Awaiting on a token does not have any results.
327   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
328     return op.emitOpError("awaiting on a token must have empty result");
329 
330   // Awaiting on a value unwraps the async value type.
331   if (auto value = argType.dyn_cast<ValueType>()) {
332     if (*op.getResultType() != value.getValueType())
333       return op.emitOpError()
334              << "result type " << *op.getResultType()
335              << " does not match async value type " << value.getValueType();
336   }
337 
338   return success();
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // TableGen'd op method definitions
343 //===----------------------------------------------------------------------===//
344 
345 #define GET_OP_CLASSES
346 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
347 
348 //===----------------------------------------------------------------------===//
349 // TableGen'd type method definitions
350 //===----------------------------------------------------------------------===//
351 
352 #define GET_TYPEDEF_CLASSES
353 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
354 
355 void ValueType::print(DialectAsmPrinter &printer) const {
356   printer << getMnemonic();
357   printer << "<";
358   printer.printType(getValueType());
359   printer << '>';
360 }
361 
362 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
363   Type ty;
364   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
365     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
366     return Type();
367   }
368   return ValueType::get(ty);
369 }
370 
371 /// Print a type registered to this dialect.
372 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
373   if (failed(generatedTypePrinter(type, os)))
374     llvm_unreachable("unexpected 'async' type kind");
375 }
376 
377 /// Parse a type registered to this dialect.
378 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
379   StringRef typeTag;
380   if (parser.parseKeyword(&typeTag))
381     return Type();
382   Type genType;
383   auto parseResult = generatedTypeParser(parser.getBuilder().getContext(),
384                                          parser, typeTag, genType);
385   if (parseResult.hasValue())
386     return genType;
387   parser.emitError(parser.getNameLoc(), "unknown async type: ") << typeTag;
388   return {};
389 }
390