1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Async/IR/Async.h"
10 
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
/// Registers all TableGen-generated operations and types of the Async dialect.
void AsyncDialect::initialize() {
  // With GET_OP_LIST defined, the included file expands to the list of
  // generated op classes used as template arguments here.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  // Likewise, GET_TYPEDEF_LIST expands to the list of generated type classes.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
27 
28 //===----------------------------------------------------------------------===//
29 // YieldOp
30 //===----------------------------------------------------------------------===//
31 
32 static LogicalResult verify(YieldOp op) {
33   // Get the underlying value types from async values returned from the
34   // parent `async.execute` operation.
35   auto executeOp = op->getParentOfType<ExecuteOp>();
36   auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
37     return result.getType().cast<ValueType>().getValueType();
38   });
39 
40   if (op.getOperandTypes() != types)
41     return op.emitOpError("operand types do not match the types returned from "
42                           "the parent ExecuteOp");
43 
44   return success();
45 }
46 
47 //===----------------------------------------------------------------------===//
// ExecuteOp
49 //===----------------------------------------------------------------------===//
50 
/// Name of the attribute recording how the flat operand list of
/// `async.execute` splits into its variadic groups (dependency tokens first,
/// then async value operands).
static constexpr const char kOperandSegmentSizesAttr[] =
    "operand_segment_sizes";
52 
53 void ExecuteOp::getNumRegionInvocations(
54     ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
55   (void)operands;
56   assert(countPerRegion.empty());
57   countPerRegion.push_back(1);
58 }
59 
60 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
61                                     ArrayRef<Attribute> operands,
62                                     SmallVectorImpl<RegionSuccessor> &regions) {
63   // The `body` region branch back to the parent operation.
64   if (index.hasValue()) {
65     assert(*index == 0);
66     regions.push_back(RegionSuccessor(getResults()));
67     return;
68   }
69 
70   // Otherwise the successor is the body region.
71   regions.push_back(RegionSuccessor(&body()));
72 }
73 
74 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
75                       TypeRange resultTypes, ValueRange dependencies,
76                       ValueRange operands, BodyBuilderFn bodyBuilder) {
77 
78   result.addOperands(dependencies);
79   result.addOperands(operands);
80 
81   // Add derived `operand_segment_sizes` attribute based on parsed operands.
82   int32_t numDependencies = dependencies.size();
83   int32_t numOperands = operands.size();
84   auto operandSegmentSizes = DenseIntElementsAttr::get(
85       VectorType::get({2}, builder.getIntegerType(32)),
86       {numDependencies, numOperands});
87   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
88 
89   // First result is always a token, and then `resultTypes` wrapped into
90   // `async.value`.
91   result.addTypes({TokenType::get(result.getContext())});
92   for (Type type : resultTypes)
93     result.addTypes(ValueType::get(type));
94 
95   // Add a body region with block arguments as unwrapped async value operands.
96   Region *bodyRegion = result.addRegion();
97   bodyRegion->push_back(new Block);
98   Block &bodyBlock = bodyRegion->front();
99   for (Value operand : operands) {
100     auto valueType = operand.getType().dyn_cast<ValueType>();
101     bodyBlock.addArgument(valueType ? valueType.getValueType()
102                                     : operand.getType());
103   }
104 
105   // Create the default terminator if the builder is not provided and if the
106   // expected result is empty. Otherwise, leave this to the caller
107   // because we don't know which values to return from the execute op.
108   if (resultTypes.empty() && !bodyBuilder) {
109     OpBuilder::InsertionGuard guard(builder);
110     builder.setInsertionPointToStart(&bodyBlock);
111     builder.create<async::YieldOp>(result.location, ValueRange());
112   } else if (bodyBuilder) {
113     OpBuilder::InsertionGuard guard(builder);
114     builder.setInsertionPointToStart(&bodyBlock);
115     bodyBuilder(builder, result.location, bodyBlock.getArguments());
116   }
117 }
118 
119 static void print(OpAsmPrinter &p, ExecuteOp op) {
120   p << op.getOperationName();
121 
122   // [%tokens,...]
123   if (!op.dependencies().empty())
124     p << " [" << op.dependencies() << "]";
125 
126   // (%value as %unwrapped: !async.value<!arg.type>, ...)
127   if (!op.operands().empty()) {
128     p << " (";
129     llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
130       p << operand << " as " << op.body().front().getArgument(n++) << ": "
131         << operand.getType();
132     });
133     p << ")";
134   }
135 
136   // -> (!async.value<!return.type>, ...)
137   p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
138   p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
139   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
140 }
141 
142 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
143   MLIRContext *ctx = result.getContext();
144 
145   // Sizes of parsed variadic operands, will be updated below after parsing.
146   int32_t numDependencies = 0;
147   int32_t numOperands = 0;
148 
149   auto tokenTy = TokenType::get(ctx);
150 
151   // Parse dependency tokens.
152   if (succeeded(parser.parseOptionalLSquare())) {
153     SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
154     if (parser.parseOperandList(tokenArgs) ||
155         parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
156         parser.parseRSquare())
157       return failure();
158 
159     numDependencies = tokenArgs.size();
160   }
161 
162   // Parse async value operands (%value as %unwrapped : !async.value<!type>).
163   SmallVector<OpAsmParser::OperandType, 4> valueArgs;
164   SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
165   SmallVector<Type, 4> valueTypes;
166   SmallVector<Type, 4> unwrappedTypes;
167 
168   if (succeeded(parser.parseOptionalLParen())) {
169     auto argsLoc = parser.getCurrentLocation();
170 
171     // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
172     auto parseAsyncValueArg = [&]() -> ParseResult {
173       if (parser.parseOperand(valueArgs.emplace_back()) ||
174           parser.parseKeyword("as") ||
175           parser.parseOperand(unwrappedArgs.emplace_back()) ||
176           parser.parseColonType(valueTypes.emplace_back()))
177         return failure();
178 
179       auto valueTy = valueTypes.back().dyn_cast<ValueType>();
180       unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
181 
182       return success();
183     };
184 
185     // If the next token is `)` skip async value arguments parsing.
186     if (failed(parser.parseOptionalRParen())) {
187       do {
188         if (parseAsyncValueArg())
189           return failure();
190       } while (succeeded(parser.parseOptionalComma()));
191 
192       if (parser.parseRParen() ||
193           parser.resolveOperands(valueArgs, valueTypes, argsLoc,
194                                  result.operands))
195         return failure();
196     }
197 
198     numOperands = valueArgs.size();
199   }
200 
201   // Add derived `operand_segment_sizes` attribute based on parsed operands.
202   auto operandSegmentSizes = DenseIntElementsAttr::get(
203       VectorType::get({2}, parser.getBuilder().getI32Type()),
204       {numDependencies, numOperands});
205   result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
206 
207   // Parse the types of results returned from the async execute op.
208   SmallVector<Type, 4> resultTypes;
209   if (parser.parseOptionalArrowTypeList(resultTypes))
210     return failure();
211 
212   // Async execute first result is always a completion token.
213   parser.addTypeToList(tokenTy, result.types);
214   parser.addTypesToList(resultTypes, result.types);
215 
216   // Parse operation attributes.
217   NamedAttrList attrs;
218   if (parser.parseOptionalAttrDictWithKeyword(attrs))
219     return failure();
220   result.addAttributes(attrs);
221 
222   // Parse asynchronous region.
223   Region *body = result.addRegion();
224   if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
225                          /*argTypes=*/{unwrappedTypes},
226                          /*enableNameShadowing=*/false))
227     return failure();
228 
229   return success();
230 }
231 
232 static LogicalResult verify(ExecuteOp op) {
233   // Unwrap async.execute value operands types.
234   auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
235     return operand.getType().cast<ValueType>().getValueType();
236   });
237 
238   // Verify that unwrapped argument types matches the body region arguments.
239   if (op.body().getArgumentTypes() != unwrappedTypes)
240     return op.emitOpError("async body region argument types do not match the "
241                           "execute operation arguments types");
242 
243   return success();
244 }
245 
246 //===----------------------------------------------------------------------===//
// AwaitOp
248 //===----------------------------------------------------------------------===//
249 
250 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
251                     ArrayRef<NamedAttribute> attrs) {
252   result.addOperands({operand});
253   result.attributes.append(attrs.begin(), attrs.end());
254 
255   // Add unwrapped async.value type to the returned values types.
256   if (auto valueType = operand.getType().dyn_cast<ValueType>())
257     result.addTypes(valueType.getValueType());
258 }
259 
260 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
261                                         Type &resultType) {
262   if (parser.parseType(operandType))
263     return failure();
264 
265   // Add unwrapped async.value type to the returned values types.
266   if (auto valueType = operandType.dyn_cast<ValueType>())
267     resultType = valueType.getValueType();
268 
269   return success();
270 }
271 
272 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
273                                  Type operandType, Type resultType) {
274   p << operandType;
275 }
276 
277 static LogicalResult verify(AwaitOp op) {
278   Type argType = op.operand().getType();
279 
280   // Awaiting on a token does not have any results.
281   if (argType.isa<TokenType>() && !op.getResultTypes().empty())
282     return op.emitOpError("awaiting on a token must have empty result");
283 
284   // Awaiting on a value unwraps the async value type.
285   if (auto value = argType.dyn_cast<ValueType>()) {
286     if (*op.getResultType() != value.getValueType())
287       return op.emitOpError()
288              << "result type " << *op.getResultType()
289              << " does not match async value type " << value.getValueType();
290   }
291 
292   return success();
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // TableGen'd op method definitions
297 //===----------------------------------------------------------------------===//
298 
299 #define GET_OP_CLASSES
300 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
301 
302 //===----------------------------------------------------------------------===//
303 // TableGen'd type method definitions
304 //===----------------------------------------------------------------------===//
305 
306 #define GET_TYPEDEF_CLASSES
307 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
308 
309 void ValueType::print(DialectAsmPrinter &printer) const {
310   printer << getMnemonic();
311   printer << "<";
312   printer.printType(getValueType());
313   printer << '>';
314 }
315 
316 Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
317   Type ty;
318   if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
319     parser.emitError(parser.getNameLoc(), "failed to parse async value type");
320     return Type();
321   }
322   return ValueType::get(ty);
323 }
324 
325 /// Print a type registered to this dialect.
326 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
327   if (failed(generatedTypePrinter(type, os)))
328     llvm_unreachable("unexpected 'async' type kind");
329 }
330 
331 /// Parse a type registered to this dialect.
332 Type AsyncDialect::parseType(DialectAsmParser &parser) const {
333   StringRef mnemonic;
334   if (parser.parseKeyword(&mnemonic))
335     return Type();
336 
337   return generatedTypeParser(getContext(), parser, mnemonic);
338 }
339