1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Async/IR/Async.h"
10
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/TypeSwitch.h"
13
14 using namespace mlir;
15 using namespace mlir::async;
16
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18
// Namespace-scope definition of the `kAllowedToBlockAttrName` static member
// of `AsyncDialect`. No initializer is given here, so the value presumably
// comes from the in-class declaration in Async.h (not visible in this file).
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20
/// Registers the operations and types of the dialect. Both template argument
/// lists are supplied by including the generated `.cpp.inc` files with the
/// `GET_OP_LIST` / `GET_TYPEDEF_LIST` macros defined.
void AsyncDialect::initialize() {
  // Operation list comes from AsyncOps.cpp.inc (GET_OP_LIST section).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  // Type list comes from AsyncOpsTypes.cpp.inc (GET_TYPEDEF_LIST section).
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35
verify()36 LogicalResult YieldOp::verify() {
37 // Get the underlying value types from async values returned from the
38 // parent `async.execute` operation.
39 auto executeOp = (*this)->getParentOfType<ExecuteOp>();
40 auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41 return result.getType().cast<ValueType>().getValueType();
42 });
43
44 if (getOperandTypes() != types)
45 return emitOpError("operand types do not match the types returned from "
46 "the parent ExecuteOp");
47
48 return success();
49 }
50
51 MutableOperandRange
getMutableSuccessorOperands(Optional<unsigned> index)52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
53 return operandsMutable();
54 }
55
56 //===----------------------------------------------------------------------===//
57 /// ExecuteOp
58 //===----------------------------------------------------------------------===//
59
// Name of the attribute holding the sizes of the two operand segments
// (dependencies, operands) of `async.execute`. It is attached by
// ExecuteOp::build and ExecuteOp::parse and elided by ExecuteOp::print.
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
61
getSuccessorEntryOperands(Optional<unsigned> index)62 OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
63 assert(index && *index == 0 && "invalid region index");
64 return operands();
65 }
66
areTypesCompatible(Type lhs,Type rhs)67 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
68 const auto getValueOrTokenType = [](Type type) {
69 if (auto value = type.dyn_cast<ValueType>())
70 return value.getValueType();
71 return type;
72 };
73 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
74 }
75
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute>,SmallVectorImpl<RegionSuccessor> & regions)76 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
77 ArrayRef<Attribute>,
78 SmallVectorImpl<RegionSuccessor> ®ions) {
79 // The `body` region branch back to the parent operation.
80 if (index) {
81 assert(*index == 0 && "invalid region index");
82 regions.push_back(RegionSuccessor(results()));
83 return;
84 }
85
86 // Otherwise the successor is the body region.
87 regions.push_back(RegionSuccessor(&body(), body().getArguments()));
88 }
89
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ValueRange dependencies,ValueRange operands,BodyBuilderFn bodyBuilder)90 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
91 TypeRange resultTypes, ValueRange dependencies,
92 ValueRange operands, BodyBuilderFn bodyBuilder) {
93
94 result.addOperands(dependencies);
95 result.addOperands(operands);
96
97 // Add derived `operand_segment_sizes` attribute based on parsed operands.
98 int32_t numDependencies = dependencies.size();
99 int32_t numOperands = operands.size();
100 auto operandSegmentSizes = DenseIntElementsAttr::get(
101 VectorType::get({2}, builder.getIntegerType(32)),
102 {numDependencies, numOperands});
103 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
104
105 // First result is always a token, and then `resultTypes` wrapped into
106 // `async.value`.
107 result.addTypes({TokenType::get(result.getContext())});
108 for (Type type : resultTypes)
109 result.addTypes(ValueType::get(type));
110
111 // Add a body region with block arguments as unwrapped async value operands.
112 Region *bodyRegion = result.addRegion();
113 bodyRegion->push_back(new Block);
114 Block &bodyBlock = bodyRegion->front();
115 for (Value operand : operands) {
116 auto valueType = operand.getType().dyn_cast<ValueType>();
117 bodyBlock.addArgument(valueType ? valueType.getValueType()
118 : operand.getType(),
119 operand.getLoc());
120 }
121
122 // Create the default terminator if the builder is not provided and if the
123 // expected result is empty. Otherwise, leave this to the caller
124 // because we don't know which values to return from the execute op.
125 if (resultTypes.empty() && !bodyBuilder) {
126 OpBuilder::InsertionGuard guard(builder);
127 builder.setInsertionPointToStart(&bodyBlock);
128 builder.create<async::YieldOp>(result.location, ValueRange());
129 } else if (bodyBuilder) {
130 OpBuilder::InsertionGuard guard(builder);
131 builder.setInsertionPointToStart(&bodyBlock);
132 bodyBuilder(builder, result.location, bodyBlock.getArguments());
133 }
134 }
135
print(OpAsmPrinter & p)136 void ExecuteOp::print(OpAsmPrinter &p) {
137 // [%tokens,...]
138 if (!dependencies().empty())
139 p << " [" << dependencies() << "]";
140
141 // (%value as %unwrapped: !async.value<!arg.type>, ...)
142 if (!operands().empty()) {
143 p << " (";
144 Block *entry = body().empty() ? nullptr : &body().front();
145 llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable {
146 Value argument = entry ? entry->getArgument(n++) : Value();
147 p << operand << " as " << argument << ": " << operand.getType();
148 });
149 p << ")";
150 }
151
152 // -> (!async.value<!return.type>, ...)
153 p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
154 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
155 {kOperandSegmentSizesAttr});
156 p << ' ';
157 p.printRegion(body(), /*printEntryBlockArgs=*/false);
158 }
159
parse(OpAsmParser & parser,OperationState & result)160 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
161 MLIRContext *ctx = result.getContext();
162
163 // Sizes of parsed variadic operands, will be updated below after parsing.
164 int32_t numDependencies = 0;
165
166 auto tokenTy = TokenType::get(ctx);
167
168 // Parse dependency tokens.
169 if (succeeded(parser.parseOptionalLSquare())) {
170 SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
171 if (parser.parseOperandList(tokenArgs) ||
172 parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
173 parser.parseRSquare())
174 return failure();
175
176 numDependencies = tokenArgs.size();
177 }
178
179 // Parse async value operands (%value as %unwrapped : !async.value<!type>).
180 SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
181 SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
182 SmallVector<Type, 4> valueTypes;
183
184 // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
185 auto parseAsyncValueArg = [&]() -> ParseResult {
186 if (parser.parseOperand(valueArgs.emplace_back()) ||
187 parser.parseKeyword("as") ||
188 parser.parseArgument(unwrappedArgs.emplace_back()) ||
189 parser.parseColonType(valueTypes.emplace_back()))
190 return failure();
191
192 auto valueTy = valueTypes.back().dyn_cast<ValueType>();
193 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
194 return success();
195 };
196
197 auto argsLoc = parser.getCurrentLocation();
198 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
199 parseAsyncValueArg) ||
200 parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
201 return failure();
202
203 int32_t numOperands = valueArgs.size();
204
205 // Add derived `operand_segment_sizes` attribute based on parsed operands.
206 auto operandSegmentSizes = DenseIntElementsAttr::get(
207 VectorType::get({2}, parser.getBuilder().getI32Type()),
208 {numDependencies, numOperands});
209 result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
210
211 // Parse the types of results returned from the async execute op.
212 SmallVector<Type, 4> resultTypes;
213 NamedAttrList attrs;
214 if (parser.parseOptionalArrowTypeList(resultTypes) ||
215 // Async execute first result is always a completion token.
216 parser.addTypeToList(tokenTy, result.types) ||
217 parser.addTypesToList(resultTypes, result.types) ||
218 // Parse operation attributes.
219 parser.parseOptionalAttrDictWithKeyword(attrs))
220 return failure();
221
222 result.addAttributes(attrs);
223
224 // Parse asynchronous region.
225 Region *body = result.addRegion();
226 return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
227 }
228
verifyRegions()229 LogicalResult ExecuteOp::verifyRegions() {
230 // Unwrap async.execute value operands types.
231 auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) {
232 return operand.getType().cast<ValueType>().getValueType();
233 });
234
235 // Verify that unwrapped argument types matches the body region arguments.
236 if (body().getArgumentTypes() != unwrappedTypes)
237 return emitOpError("async body region argument types do not match the "
238 "execute operation arguments types");
239
240 return success();
241 }
242
243 //===----------------------------------------------------------------------===//
244 /// CreateGroupOp
245 //===----------------------------------------------------------------------===//
246
canonicalize(CreateGroupOp op,PatternRewriter & rewriter)247 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
248 PatternRewriter &rewriter) {
249 // Find all `await_all` users of the group.
250 llvm::SmallVector<AwaitAllOp> awaitAllUsers;
251
252 auto isAwaitAll = [&](Operation *op) -> bool {
253 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
254 awaitAllUsers.push_back(awaitAll);
255 return true;
256 }
257 return false;
258 };
259
260 // Check if all users of the group are `await_all` operations.
261 if (!llvm::all_of(op->getUsers(), isAwaitAll))
262 return failure();
263
264 // If group is only awaited without adding anything to it, we can safely erase
265 // the create operation and all users.
266 for (AwaitAllOp awaitAll : awaitAllUsers)
267 rewriter.eraseOp(awaitAll);
268 rewriter.eraseOp(op);
269
270 return success();
271 }
272
273 //===----------------------------------------------------------------------===//
274 /// AwaitOp
275 //===----------------------------------------------------------------------===//
276
build(OpBuilder & builder,OperationState & result,Value operand,ArrayRef<NamedAttribute> attrs)277 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
278 ArrayRef<NamedAttribute> attrs) {
279 result.addOperands({operand});
280 result.attributes.append(attrs.begin(), attrs.end());
281
282 // Add unwrapped async.value type to the returned values types.
283 if (auto valueType = operand.getType().dyn_cast<ValueType>())
284 result.addTypes(valueType.getValueType());
285 }
286
parseAwaitResultType(OpAsmParser & parser,Type & operandType,Type & resultType)287 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
288 Type &resultType) {
289 if (parser.parseType(operandType))
290 return failure();
291
292 // Add unwrapped async.value type to the returned values types.
293 if (auto valueType = operandType.dyn_cast<ValueType>())
294 resultType = valueType.getValueType();
295
296 return success();
297 }
298
printAwaitResultType(OpAsmPrinter & p,Operation * op,Type operandType,Type resultType)299 static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
300 Type operandType, Type resultType) {
301 p << operandType;
302 }
303
verify()304 LogicalResult AwaitOp::verify() {
305 Type argType = operand().getType();
306
307 // Awaiting on a token does not have any results.
308 if (argType.isa<TokenType>() && !getResultTypes().empty())
309 return emitOpError("awaiting on a token must have empty result");
310
311 // Awaiting on a value unwraps the async value type.
312 if (auto value = argType.dyn_cast<ValueType>()) {
313 if (*getResultType() != value.getValueType())
314 return emitOpError() << "result type " << *getResultType()
315 << " does not match async value type "
316 << value.getValueType();
317 }
318
319 return success();
320 }
321
322 //===----------------------------------------------------------------------===//
323 // TableGen'd op method definitions
324 //===----------------------------------------------------------------------===//
325
326 #define GET_OP_CLASSES
327 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
328
329 //===----------------------------------------------------------------------===//
330 // TableGen'd type method definitions
331 //===----------------------------------------------------------------------===//
332
333 #define GET_TYPEDEF_CLASSES
334 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
335
print(AsmPrinter & printer) const336 void ValueType::print(AsmPrinter &printer) const {
337 printer << "<";
338 printer.printType(getValueType());
339 printer << '>';
340 }
341
parse(mlir::AsmParser & parser)342 Type ValueType::parse(mlir::AsmParser &parser) {
343 Type ty;
344 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
345 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
346 return Type();
347 }
348 return ValueType::get(ty);
349 }
350