1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// Out-of-line definition of the static constexpr member declared in the
// AsyncDialect class. NOTE(review): presumably kept so the member can be
// ODR-used when building as C++14, where static constexpr data members are not
// implicitly inline; under C++17 this definition is redundant (deprecated but
// still valid).
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20 
/// Registers every Async dialect operation and type with the dialect.
/// Both lists are generated by TableGen and spliced in from the
/// `GET_OP_LIST` / `GET_TYPEDEF_LIST` sections of the `.inc` files.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
36 LogicalResult YieldOp::verify() {
37   // Get the underlying value types from async values returned from the
38   // parent `async.execute` operation.
39   auto executeOp = (*this)->getParentOfType<ExecuteOp>();
40   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41     return result.getType().cast<ValueType>().getValueType();
42   });
43 
44   if (getOperandTypes() != types)
45     return emitOpError("operand types do not match the types returned from "
46                        "the parent ExecuteOp");
47 
48   return success();
49 }
50 
51 MutableOperandRange
52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
53   return operandsMutable();
54 }
55 
56 //===----------------------------------------------------------------------===//
// ExecuteOp
58 //===----------------------------------------------------------------------===//
59 
// Name of the attribute that records how many operands belong to the
// `dependencies` segment and to the `operands` segment of `async.execute`.
// It is derived from the operand lists, so the custom printer elides it and
// the custom parser and builder recompute it.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
61 
62 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
63   assert(index == 0 && "invalid region index");
64   return operands();
65 }
66 
67 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
68   const auto getValueOrTokenType = [](Type type) {
69     if (auto value = type.dyn_cast<ValueType>())
70       return value.getValueType();
71     return type;
72   };
73   return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
74 }
75 
76 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
77                                     ArrayRef<Attribute>,
78                                     SmallVectorImpl<RegionSuccessor> &regions) {
79   // The `body` region branch back to the parent operation.
80   if (index.hasValue()) {
81     assert(*index == 0 && "invalid region index");
82     regions.push_back(RegionSuccessor(results()));
83     return;
84   }
85 
86   // Otherwise the successor is the body region.
87   regions.push_back(RegionSuccessor(&body(), body().getArguments()));
88 }
89 
90 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
91                       TypeRange resultTypes, ValueRange dependencies,
92                       ValueRange operands, BodyBuilderFn bodyBuilder) {
93 
94   result.addOperands(dependencies);
95   result.addOperands(operands);
96 
97   // Add derived `operand_segment_sizes` attribute based on parsed operands.
98   int32_t numDependencies = dependencies.size();
99   int32_t numOperands = operands.size();
100   auto operandSegmentSizes = DenseIntElementsAttr::get(
101       VectorType::get({2}, builder.getIntegerType(32)),
102       {numDependencies, numOperands});
103   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
104 
105   // First result is always a token, and then `resultTypes` wrapped into
106   // `async.value`.
107   result.addTypes({TokenType::get(result.getContext())});
108   for (Type type : resultTypes)
109     result.addTypes(ValueType::get(type));
110 
111   // Add a body region with block arguments as unwrapped async value operands.
112   Region *bodyRegion = result.addRegion();
113   bodyRegion->push_back(new Block);
114   Block &bodyBlock = bodyRegion->front();
115   for (Value operand : operands) {
116     auto valueType = operand.getType().dyn_cast<ValueType>();
117     bodyBlock.addArgument(valueType ? valueType.getValueType()
118                                     : operand.getType(),
119                           operand.getLoc());
120   }
121 
122   // Create the default terminator if the builder is not provided and if the
123   // expected result is empty. Otherwise, leave this to the caller
124   // because we don't know which values to return from the execute op.
125   if (resultTypes.empty() && !bodyBuilder) {
126     OpBuilder::InsertionGuard guard(builder);
127     builder.setInsertionPointToStart(&bodyBlock);
128     builder.create<async::YieldOp>(result.location, ValueRange());
129   } else if (bodyBuilder) {
130     OpBuilder::InsertionGuard guard(builder);
131     builder.setInsertionPointToStart(&bodyBlock);
132     bodyBuilder(builder, result.location, bodyBlock.getArguments());
133   }
134 }
135 
136 void ExecuteOp::print(OpAsmPrinter &p) {
137   // [%tokens,...]
138   if (!dependencies().empty())
139     p << " [" << dependencies() << "]";
140 
141   // (%value as %unwrapped: !async.value<!arg.type>, ...)
142   if (!operands().empty()) {
143     p << " (";
144     Block *entry = body().empty() ? nullptr : &body().front();
145     llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable {
146       Value argument = entry ? entry->getArgument(n++) : Value();
147       p << operand << " as " << argument << ": " << operand.getType();
148     });
149     p << ")";
150   }
151 
152   // -> (!async.value<!return.type>, ...)
153   p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
154   p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
155                                      {kOperandSegmentSizesAttr});
156   p << ' ';
157   p.printRegion(body(), /*printEntryBlockArgs=*/false);
158 }
159 
160 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
161   MLIRContext *ctx = result.getContext();
162 
163   // Sizes of parsed variadic operands, will be updated below after parsing.
164   int32_t numDependencies = 0;
165 
166   auto tokenTy = TokenType::get(ctx);
167 
168   // Parse dependency tokens.
169   if (succeeded(parser.parseOptionalLSquare())) {
170     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
171     if (parser.parseOperandList(tokenArgs) ||
172         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
173         parser.parseRSquare())
174       return failure();
175 
176     numDependencies = tokenArgs.size();
177   }
178 
179   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
180   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
181   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
182   SmallVector<Type, 4> valueTypes;
183   SmallVector<Type, 4> unwrappedTypes;
184 
185   // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
186   auto parseAsyncValueArg = [&]() -> ParseResult {
187     if (parser.parseOperand(valueArgs.emplace_back()) ||
188         parser.parseKeyword("as") ||
189         parser.parseOperand(unwrappedArgs.emplace_back()) ||
190         parser.parseColonType(valueTypes.emplace_back()))
191       return failure();
192 
193     auto valueTy = valueTypes.back().dyn_cast<ValueType>();
194     unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
195 
196     return success();
197   };
198 
199   auto argsLoc = parser.getCurrentLocation();
200   if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
201                                      parseAsyncValueArg) ||
202       parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
203     return failure();
204 
205   int32_t numOperands = valueArgs.size();
206 
207   // Add derived `operand_segment_sizes` attribute based on parsed operands.
208   auto operandSegmentSizes = DenseIntElementsAttr::get(
209       VectorType::get({2}, parser.getBuilder().getI32Type()),
210       {numDependencies, numOperands});
211   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
212 
213   // Parse the types of results returned from the async execute op.
214   SmallVector<Type, 4> resultTypes;
215   if (parser.parseOptionalArrowTypeList(resultTypes))
216     return failure();
217 
218   // Async execute first result is always a completion token.
219   parser.addTypeToList(tokenTy, result.types);
220   parser.addTypesToList(resultTypes, result.types);
221 
222   // Parse operation attributes.
223   NamedAttrList attrs;
224   if (parser.parseOptionalAttrDictWithKeyword(attrs))
225     return failure();
226   result.addAttributes(attrs);
227 
228   // Parse asynchronous region.
229   Region *body = result.addRegion();
230   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
231                          /*argTypes=*/{unwrappedTypes},
232                          /*argLocations=*/{},
233                          /*enableNameShadowing=*/false))
234     return failure();
235 
236   return success();
237 }
238 
239 LogicalResult ExecuteOp::verifyRegions() {
240   // Unwrap async.execute value operands types.
241   auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) {
242     return operand.getType().cast<ValueType>().getValueType();
243   });
244 
245   // Verify that unwrapped argument types matches the body region arguments.
246   if (body().getArgumentTypes() != unwrappedTypes)
247     return emitOpError("async body region argument types do not match the "
248                        "execute operation arguments types");
249 
250   return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
// CreateGroupOp
255 //===----------------------------------------------------------------------===//
256 
257 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
258                                           PatternRewriter &rewriter) {
259   // Find all `await_all` users of the group.
260   llvm::SmallVector<AwaitAllOp> awaitAllUsers;
261 
262   auto isAwaitAll = [&](Operation *op) -> bool {
263     if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
264       awaitAllUsers.push_back(awaitAll);
265       return true;
266     }
267     return false;
268   };
269 
270   // Check if all users of the group are `await_all` operations.
271   if (!llvm::all_of(op->getUsers(), isAwaitAll))
272     return failure();
273 
274   // If group is only awaited without adding anything to it, we can safely erase
275   // the create operation and all users.
276   for (AwaitAllOp awaitAll : awaitAllUsers)
277     rewriter.eraseOp(awaitAll);
278   rewriter.eraseOp(op);
279 
280   return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
// AwaitOp
285 //===----------------------------------------------------------------------===//
286 
287 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
288                     ArrayRef<NamedAttribute> attrs) {
289   result.addOperands({operand});
290   result.attributes.append(attrs.begin(), attrs.end());
291 
292   // Add unwrapped async.value type to the returned values types.
293   if (auto valueType = operand.getType().dyn_cast<ValueType>())
294     result.addTypes(valueType.getValueType());
295 }
296 
297 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
298                                         Type &resultType) {
299   if (parser.parseType(operandType))
300     return failure();
301 
302   // Add unwrapped async.value type to the returned values types.
303   if (auto valueType = operandType.dyn_cast<ValueType>())
304     resultType = valueType.getValueType();
305 
306   return success();
307 }
308 
309 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
310                                  Type operandType, Type resultType) {
311   p << operandType;
312 }
313 
314 LogicalResult AwaitOp::verify() {
315   Type argType = operand().getType();
316 
317   // Awaiting on a token does not have any results.
318   if (argType.isa<TokenType>() && !getResultTypes().empty())
319     return emitOpError("awaiting on a token must have empty result");
320 
321   // Awaiting on a value unwraps the async value type.
322   if (auto value = argType.dyn_cast<ValueType>()) {
323     if (*getResultType() != value.getValueType())
324       return emitOpError() << "result type " << *getResultType()
325                            << " does not match async value type "
326                            << value.getValueType();
327   }
328 
329   return success();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // TableGen'd op method definitions
334 //===----------------------------------------------------------------------===//
335 
336 #define GET_OP_CLASSES
337 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
338 
339 //===----------------------------------------------------------------------===//
340 // TableGen'd type method definitions
341 //===----------------------------------------------------------------------===//
342 
343 #define GET_TYPEDEF_CLASSES
344 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
345 
346 void ValueType::print(AsmPrinter &printer) const {
347   printer << "<";
348   printer.printType(getValueType());
349   printer << '>';
350 }
351 
352 Type ValueType::parse(mlir::AsmParser &parser) {
353   Type ty;
354   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
355     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
356     return Type();
357   }
358   return ValueType::get(ty);
359 }
360