1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
19 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 #include "mlir/IR/ImplicitLocOpBuilder.h"
22 #include "mlir/IR/TypeUtilities.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/ADT/TypeSwitch.h"
26
27 #define DEBUG_TYPE "convert-async-to-llvm"
28
29 using namespace mlir;
30 using namespace mlir::async;
31
32 //===----------------------------------------------------------------------===//
33 // Async Runtime C API declaration.
34 //===----------------------------------------------------------------------===//
35
// Symbol names of the Async Runtime C API functions. Each one is declared in
// the module by addAsyncRuntimeApiDeclarations and called by name from the
// lowering patterns below.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Used when lowering async.runtime.create / async.runtime.create_group.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Used when lowering async.runtime.set_available.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";

// Used when lowering async.runtime.set_error.
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Used when lowering async.runtime.is_error.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Used when lowering async.runtime.await.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Used when lowering async.runtime.resume.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";

// Used when lowering async.runtime.await_and_resume.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Value storage access and group membership.
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// NOTE(review): this symbol is spelled "Runtim" (no 'e'). Presumably it has to
// match the name exported by the runtime library - confirm before "fixing" it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
64
65 namespace {
66 /// Async Runtime API function types.
67 ///
68 /// Because we can't create API function signature for type parametrized
69 /// async.value type, we use opaque pointers (!llvm.ptr<i8>) instead. After
70 /// lowering all async data types become opaque pointers at runtime.
71 struct AsyncAPI {
72 // All async types are lowered to opaque i8* LLVM pointers at runtime.
opaquePointerType__anonbcc505750111::AsyncAPI73 static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
74 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
75 }
76
tokenType__anonbcc505750111::AsyncAPI77 static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
78 return LLVM::LLVMTokenType::get(ctx);
79 }
80
addOrDropRefFunctionType__anonbcc505750111::AsyncAPI81 static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
82 auto ref = opaquePointerType(ctx);
83 auto count = IntegerType::get(ctx, 64);
84 return FunctionType::get(ctx, {ref, count}, {});
85 }
86
createTokenFunctionType__anonbcc505750111::AsyncAPI87 static FunctionType createTokenFunctionType(MLIRContext *ctx) {
88 return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
89 }
90
createValueFunctionType__anonbcc505750111::AsyncAPI91 static FunctionType createValueFunctionType(MLIRContext *ctx) {
92 auto i64 = IntegerType::get(ctx, 64);
93 auto value = opaquePointerType(ctx);
94 return FunctionType::get(ctx, {i64}, {value});
95 }
96
createGroupFunctionType__anonbcc505750111::AsyncAPI97 static FunctionType createGroupFunctionType(MLIRContext *ctx) {
98 auto i64 = IntegerType::get(ctx, 64);
99 return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
100 }
101
getValueStorageFunctionType__anonbcc505750111::AsyncAPI102 static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
103 auto value = opaquePointerType(ctx);
104 auto storage = opaquePointerType(ctx);
105 return FunctionType::get(ctx, {value}, {storage});
106 }
107
emplaceTokenFunctionType__anonbcc505750111::AsyncAPI108 static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
109 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
110 }
111
emplaceValueFunctionType__anonbcc505750111::AsyncAPI112 static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
113 auto value = opaquePointerType(ctx);
114 return FunctionType::get(ctx, {value}, {});
115 }
116
setTokenErrorFunctionType__anonbcc505750111::AsyncAPI117 static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
118 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
119 }
120
setValueErrorFunctionType__anonbcc505750111::AsyncAPI121 static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
122 auto value = opaquePointerType(ctx);
123 return FunctionType::get(ctx, {value}, {});
124 }
125
isTokenErrorFunctionType__anonbcc505750111::AsyncAPI126 static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
127 auto i1 = IntegerType::get(ctx, 1);
128 return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
129 }
130
isValueErrorFunctionType__anonbcc505750111::AsyncAPI131 static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
132 auto value = opaquePointerType(ctx);
133 auto i1 = IntegerType::get(ctx, 1);
134 return FunctionType::get(ctx, {value}, {i1});
135 }
136
isGroupErrorFunctionType__anonbcc505750111::AsyncAPI137 static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
138 auto i1 = IntegerType::get(ctx, 1);
139 return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
140 }
141
awaitTokenFunctionType__anonbcc505750111::AsyncAPI142 static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
143 return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
144 }
145
awaitValueFunctionType__anonbcc505750111::AsyncAPI146 static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
147 auto value = opaquePointerType(ctx);
148 return FunctionType::get(ctx, {value}, {});
149 }
150
awaitGroupFunctionType__anonbcc505750111::AsyncAPI151 static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
152 return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
153 }
154
executeFunctionType__anonbcc505750111::AsyncAPI155 static FunctionType executeFunctionType(MLIRContext *ctx) {
156 auto hdl = opaquePointerType(ctx);
157 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
158 return FunctionType::get(ctx, {hdl, resume}, {});
159 }
160
addTokenToGroupFunctionType__anonbcc505750111::AsyncAPI161 static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
162 auto i64 = IntegerType::get(ctx, 64);
163 return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
164 {i64});
165 }
166
awaitTokenAndExecuteFunctionType__anonbcc505750111::AsyncAPI167 static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
168 auto hdl = opaquePointerType(ctx);
169 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
170 return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
171 }
172
awaitValueAndExecuteFunctionType__anonbcc505750111::AsyncAPI173 static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
174 auto value = opaquePointerType(ctx);
175 auto hdl = opaquePointerType(ctx);
176 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
177 return FunctionType::get(ctx, {value, hdl, resume}, {});
178 }
179
awaitAllAndExecuteFunctionType__anonbcc505750111::AsyncAPI180 static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
181 auto hdl = opaquePointerType(ctx);
182 auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx));
183 return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
184 }
185
getNumWorkerThreads__anonbcc505750111::AsyncAPI186 static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187 return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188 }
189
190 // Auxiliary coroutine resume intrinsic wrapper.
resumeFunctionType__anonbcc505750111::AsyncAPI191 static Type resumeFunctionType(MLIRContext *ctx) {
192 auto voidTy = LLVM::LLVMVoidType::get(ctx);
193 auto i8Ptr = opaquePointerType(ctx);
194 return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
195 }
196 };
197 } // namespace
198
199 /// Adds Async Runtime C API declarations to the module.
addAsyncRuntimeApiDeclarations(ModuleOp module)200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201 auto builder =
202 ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203
204 auto addFuncDecl = [&](StringRef name, FunctionType type) {
205 if (module.lookupSymbol(name))
206 return;
207 builder.create<func::FuncOp>(name, type).setPrivate();
208 };
209
210 MLIRContext *ctx = module.getContext();
211 addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212 addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213 addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214 addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215 addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216 addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217 addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218 addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219 addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220 addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221 addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222 addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223 addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224 addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225 addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226 addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227 addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228 addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229 addFuncDecl(kAwaitTokenAndExecute,
230 AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231 addFuncDecl(kAwaitValueAndExecute,
232 AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233 addFuncDecl(kAwaitAllAndExecute,
234 AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235 addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241
242 static constexpr const char *kResume = "__resume";
243
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
addResumeFunction(ModuleOp module)247 static void addResumeFunction(ModuleOp module) {
248 if (module.lookupSymbol(kResume))
249 return;
250
251 MLIRContext *ctx = module.getContext();
252 auto loc = module.getLoc();
253 auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254
255 auto voidTy = LLVM::LLVMVoidType::get(ctx);
256 auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
257
258 auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259 kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
260 resumeOp.setPrivate();
261
262 auto *block = resumeOp.addEntryBlock();
263 auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264
265 blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266 blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
AsyncRuntimeTypeConverter()278 AsyncRuntimeTypeConverter() {
279 addConversion([](Type type) { return type; });
280 addConversion(convertAsyncTypes);
281 }
282
convertAsyncTypes(Type type)283 static Optional<Type> convertAsyncTypes(Type type) {
284 if (type.isa<TokenType, GroupType, ValueType>())
285 return AsyncAPI::opaquePointerType(type.getContext());
286
287 if (type.isa<CoroIdType, CoroStateType>())
288 return AsyncAPI::tokenType(type.getContext());
289 if (type.isa<CoroHandleType>())
290 return AsyncAPI::opaquePointerType(type.getContext());
291
292 return llvm::None;
293 }
294 };
295 } // namespace
296
297 //===----------------------------------------------------------------------===//
298 // Convert async.coro.id to @llvm.coro.id intrinsic.
299 //===----------------------------------------------------------------------===//
300
301 namespace {
302 class CoroIdOpConversion : public OpConversionPattern<CoroIdOp> {
303 public:
304 using OpConversionPattern::OpConversionPattern;
305
306 LogicalResult
matchAndRewrite(CoroIdOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const307 matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
308 ConversionPatternRewriter &rewriter) const override {
309 auto token = AsyncAPI::tokenType(op->getContext());
310 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
311 auto loc = op->getLoc();
312
313 // Constants for initializing coroutine frame.
314 auto constZero = rewriter.create<LLVM::ConstantOp>(
315 loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
316 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
317
318 // Get coroutine id: @llvm.coro.id.
319 rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
320 op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
321
322 return success();
323 }
324 };
325 } // namespace
326
327 //===----------------------------------------------------------------------===//
328 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
329 //===----------------------------------------------------------------------===//
330
331 namespace {
332 class CoroBeginOpConversion : public OpConversionPattern<CoroBeginOp> {
333 public:
334 using OpConversionPattern::OpConversionPattern;
335
336 LogicalResult
matchAndRewrite(CoroBeginOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const337 matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
338 ConversionPatternRewriter &rewriter) const override {
339 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
340 auto loc = op->getLoc();
341
342 // Get coroutine frame size: @llvm.coro.size.i64.
343 Value coroSize =
344 rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
345 // Get coroutine frame alignment: @llvm.coro.align.i64.
346 Value coroAlign =
347 rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
348
349 // Round up the size to be multiple of the alignment. Since aligned_alloc
350 // requires the size parameter be an integral multiple of the alignment
351 // parameter.
352 auto makeConstant = [&](uint64_t c) {
353 return rewriter.create<LLVM::ConstantOp>(
354 op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
355 };
356 coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
357 coroSize =
358 rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
359 Value negCoroAlign =
360 rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
361 coroSize =
362 rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
363
364 // Allocate memory for the coroutine frame.
365 auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
366 op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
367 auto coroAlloc = rewriter.create<LLVM::CallOp>(
368 loc, i8Ptr, SymbolRefAttr::get(allocFuncOp),
369 ValueRange{coroAlign, coroSize});
370
371 // Begin a coroutine: @llvm.coro.begin.
372 auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id();
373 rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
374 op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)}));
375
376 return success();
377 }
378 };
379 } // namespace
380
381 //===----------------------------------------------------------------------===//
382 // Convert async.coro.free to @llvm.coro.free intrinsic.
383 //===----------------------------------------------------------------------===//
384
385 namespace {
386 class CoroFreeOpConversion : public OpConversionPattern<CoroFreeOp> {
387 public:
388 using OpConversionPattern::OpConversionPattern;
389
390 LogicalResult
matchAndRewrite(CoroFreeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const391 matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter) const override {
393 auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
394 auto loc = op->getLoc();
395
396 // Get a pointer to the coroutine frame memory: @llvm.coro.free.
397 auto coroMem =
398 rewriter.create<LLVM::CoroFreeOp>(loc, i8Ptr, adaptor.getOperands());
399
400 // Free the memory.
401 auto freeFuncOp =
402 LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
403 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, TypeRange(),
404 SymbolRefAttr::get(freeFuncOp),
405 ValueRange(coroMem.getResult()));
406
407 return success();
408 }
409 };
410 } // namespace
411
412 //===----------------------------------------------------------------------===//
413 // Convert async.coro.end to @llvm.coro.end intrinsic.
414 //===----------------------------------------------------------------------===//
415
416 namespace {
417 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
418 public:
419 using OpConversionPattern::OpConversionPattern;
420
421 LogicalResult
matchAndRewrite(CoroEndOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const422 matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
423 ConversionPatternRewriter &rewriter) const override {
424 // We are not in the block that is part of the unwind sequence.
425 auto constFalse = rewriter.create<LLVM::ConstantOp>(
426 op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
427
428 // Mark the end of a coroutine: @llvm.coro.end.
429 auto coroHdl = adaptor.handle();
430 rewriter.create<LLVM::CoroEndOp>(op->getLoc(), rewriter.getI1Type(),
431 ValueRange({coroHdl, constFalse}));
432 rewriter.eraseOp(op);
433
434 return success();
435 }
436 };
437 } // namespace
438
439 //===----------------------------------------------------------------------===//
440 // Convert async.coro.save to @llvm.coro.save intrinsic.
441 //===----------------------------------------------------------------------===//
442
443 namespace {
444 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
445 public:
446 using OpConversionPattern::OpConversionPattern;
447
448 LogicalResult
matchAndRewrite(CoroSaveOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const449 matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
450 ConversionPatternRewriter &rewriter) const override {
451 // Save the coroutine state: @llvm.coro.save
452 rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
453 op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
454
455 return success();
456 }
457 };
458 } // namespace
459
460 //===----------------------------------------------------------------------===//
461 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
462 //===----------------------------------------------------------------------===//
463
464 namespace {
465
466 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
467 /// branch to the appropriate block based on the return code.
468 ///
469 /// Before:
470 ///
471 /// ^suspended:
472 /// "opBefore"(...)
473 /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup
474 /// ^resume:
475 /// "op"(...)
476 /// ^cleanup: ...
477 /// ^suspend: ...
478 ///
479 /// After:
480 ///
481 /// ^suspended:
482 /// "opBefore"(...)
483 /// %suspend = llmv.intr.coro.suspend ...
484 /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
485 /// ^resume:
486 /// "op"(...)
487 /// ^cleanup: ...
488 /// ^suspend: ...
489 ///
490 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
491 public:
492 using OpConversionPattern::OpConversionPattern;
493
494 LogicalResult
matchAndRewrite(CoroSuspendOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const495 matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
496 ConversionPatternRewriter &rewriter) const override {
497 auto i8 = rewriter.getIntegerType(8);
498 auto i32 = rewriter.getI32Type();
499 auto loc = op->getLoc();
500
501 // This is not a final suspension point.
502 auto constFalse = rewriter.create<LLVM::ConstantOp>(
503 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
504
505 // Suspend a coroutine: @llvm.coro.suspend
506 auto coroState = adaptor.state();
507 auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
508 loc, i8, ValueRange({coroState, constFalse}));
509
510 // Cast return code to i32.
511
512 // After a suspension point decide if we should branch into resume, cleanup
513 // or suspend block of the coroutine (see @llvm.coro.suspend return code
514 // documentation).
515 llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
516 llvm::SmallVector<Block *, 2> caseDest = {op.resumeDest(),
517 op.cleanupDest()};
518 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
519 op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
520 /*defaultDestination=*/op.suspendDest(),
521 /*defaultOperands=*/ValueRange(),
522 /*caseValues=*/caseValues,
523 /*caseDestinations=*/caseDest,
524 /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
525 /*branchWeights=*/ArrayRef<int32_t>());
526
527 return success();
528 }
529 };
530 } // namespace
531
532 //===----------------------------------------------------------------------===//
533 // Convert async.runtime.create to the corresponding runtime API call.
534 //
535 // To allocate storage for the async values we use getelementptr trick:
536 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
537 //===----------------------------------------------------------------------===//
538
539 namespace {
540 class RuntimeCreateOpLowering : public OpConversionPattern<RuntimeCreateOp> {
541 public:
542 using OpConversionPattern::OpConversionPattern;
543
544 LogicalResult
matchAndRewrite(RuntimeCreateOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const545 matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
546 ConversionPatternRewriter &rewriter) const override {
547 TypeConverter *converter = getTypeConverter();
548 Type resultType = op->getResultTypes()[0];
549
550 // Tokens creation maps to a simple function call.
551 if (resultType.isa<TokenType>()) {
552 rewriter.replaceOpWithNewOp<func::CallOp>(
553 op, kCreateToken, converter->convertType(resultType));
554 return success();
555 }
556
557 // To create a value we need to compute the storage requirement.
558 if (auto value = resultType.dyn_cast<ValueType>()) {
559 // Returns the size requirements for the async value storage.
560 auto sizeOf = [&](ValueType valueType) -> Value {
561 auto loc = op->getLoc();
562 auto i64 = rewriter.getI64Type();
563
564 auto storedType = converter->convertType(valueType.getValueType());
565 auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
566
567 // %Size = getelementptr %T* null, int 1
568 // %SizeI = ptrtoint %T* %Size to i64
569 auto nullPtr = rewriter.create<LLVM::NullOp>(loc, storagePtrType);
570 auto one = rewriter.create<LLVM::ConstantOp>(
571 loc, i64, rewriter.getI64IntegerAttr(1));
572 auto gep = rewriter.create<LLVM::GEPOp>(loc, storagePtrType, nullPtr,
573 one.getResult());
574 return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
575 };
576
577 rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
578 sizeOf(value));
579
580 return success();
581 }
582
583 return rewriter.notifyMatchFailure(op, "unsupported async type");
584 }
585 };
586 } // namespace
587
588 //===----------------------------------------------------------------------===//
589 // Convert async.runtime.create_group to the corresponding runtime API call.
590 //===----------------------------------------------------------------------===//
591
592 namespace {
593 class RuntimeCreateGroupOpLowering
594 : public OpConversionPattern<RuntimeCreateGroupOp> {
595 public:
596 using OpConversionPattern::OpConversionPattern;
597
598 LogicalResult
matchAndRewrite(RuntimeCreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const599 matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
600 ConversionPatternRewriter &rewriter) const override {
601 TypeConverter *converter = getTypeConverter();
602 Type resultType = op.getResult().getType();
603
604 rewriter.replaceOpWithNewOp<func::CallOp>(
605 op, kCreateGroup, converter->convertType(resultType),
606 adaptor.getOperands());
607 return success();
608 }
609 };
610 } // namespace
611
612 //===----------------------------------------------------------------------===//
613 // Convert async.runtime.set_available to the corresponding runtime API call.
614 //===----------------------------------------------------------------------===//
615
616 namespace {
617 class RuntimeSetAvailableOpLowering
618 : public OpConversionPattern<RuntimeSetAvailableOp> {
619 public:
620 using OpConversionPattern::OpConversionPattern;
621
622 LogicalResult
matchAndRewrite(RuntimeSetAvailableOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const623 matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
624 ConversionPatternRewriter &rewriter) const override {
625 StringRef apiFuncName =
626 TypeSwitch<Type, StringRef>(op.operand().getType())
627 .Case<TokenType>([](Type) { return kEmplaceToken; })
628 .Case<ValueType>([](Type) { return kEmplaceValue; });
629
630 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
631 adaptor.getOperands());
632
633 return success();
634 }
635 };
636 } // namespace
637
638 //===----------------------------------------------------------------------===//
639 // Convert async.runtime.set_error to the corresponding runtime API call.
640 //===----------------------------------------------------------------------===//
641
642 namespace {
643 class RuntimeSetErrorOpLowering
644 : public OpConversionPattern<RuntimeSetErrorOp> {
645 public:
646 using OpConversionPattern::OpConversionPattern;
647
648 LogicalResult
matchAndRewrite(RuntimeSetErrorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const649 matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651 StringRef apiFuncName =
652 TypeSwitch<Type, StringRef>(op.operand().getType())
653 .Case<TokenType>([](Type) { return kSetTokenError; })
654 .Case<ValueType>([](Type) { return kSetValueError; });
655
656 rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
657 adaptor.getOperands());
658
659 return success();
660 }
661 };
662 } // namespace
663
664 //===----------------------------------------------------------------------===//
665 // Convert async.runtime.is_error to the corresponding runtime API call.
666 //===----------------------------------------------------------------------===//
667
668 namespace {
669 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
670 public:
671 using OpConversionPattern::OpConversionPattern;
672
673 LogicalResult
matchAndRewrite(RuntimeIsErrorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const674 matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter) const override {
676 StringRef apiFuncName =
677 TypeSwitch<Type, StringRef>(op.operand().getType())
678 .Case<TokenType>([](Type) { return kIsTokenError; })
679 .Case<GroupType>([](Type) { return kIsGroupError; })
680 .Case<ValueType>([](Type) { return kIsValueError; });
681
682 rewriter.replaceOpWithNewOp<func::CallOp>(
683 op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
684 return success();
685 }
686 };
687 } // namespace
688
689 //===----------------------------------------------------------------------===//
690 // Convert async.runtime.await to the corresponding runtime API call.
691 //===----------------------------------------------------------------------===//
692
693 namespace {
694 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
695 public:
696 using OpConversionPattern::OpConversionPattern;
697
698 LogicalResult
matchAndRewrite(RuntimeAwaitOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const699 matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
700 ConversionPatternRewriter &rewriter) const override {
701 StringRef apiFuncName =
702 TypeSwitch<Type, StringRef>(op.operand().getType())
703 .Case<TokenType>([](Type) { return kAwaitToken; })
704 .Case<ValueType>([](Type) { return kAwaitValue; })
705 .Case<GroupType>([](Type) { return kAwaitGroup; });
706
707 rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
708 adaptor.getOperands());
709 rewriter.eraseOp(op);
710
711 return success();
712 }
713 };
714 } // namespace
715
716 //===----------------------------------------------------------------------===//
717 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
718 //===----------------------------------------------------------------------===//
719
720 namespace {
721 class RuntimeAwaitAndResumeOpLowering
722 : public OpConversionPattern<RuntimeAwaitAndResumeOp> {
723 public:
724 using OpConversionPattern::OpConversionPattern;
725
726 LogicalResult
matchAndRewrite(RuntimeAwaitAndResumeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const727 matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
728 ConversionPatternRewriter &rewriter) const override {
729 StringRef apiFuncName =
730 TypeSwitch<Type, StringRef>(op.operand().getType())
731 .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
732 .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
733 .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
734
735 Value operand = adaptor.operand();
736 Value handle = adaptor.handle();
737
738 // A pointer to coroutine resume intrinsic wrapper.
739 addResumeFunction(op->getParentOfType<ModuleOp>());
740 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
741 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
742 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
743
744 rewriter.create<func::CallOp>(
745 op->getLoc(), apiFuncName, TypeRange(),
746 ValueRange({operand, handle, resumePtr.getRes()}));
747 rewriter.eraseOp(op);
748
749 return success();
750 }
751 };
752 } // namespace
753
754 //===----------------------------------------------------------------------===//
755 // Convert async.runtime.resume to the corresponding runtime API call.
756 //===----------------------------------------------------------------------===//
757
758 namespace {
759 class RuntimeResumeOpLowering : public OpConversionPattern<RuntimeResumeOp> {
760 public:
761 using OpConversionPattern::OpConversionPattern;
762
763 LogicalResult
matchAndRewrite(RuntimeResumeOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const764 matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
765 ConversionPatternRewriter &rewriter) const override {
766 // A pointer to coroutine resume intrinsic wrapper.
767 addResumeFunction(op->getParentOfType<ModuleOp>());
768 auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
769 auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
770 op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);
771
772 // Call async runtime API to execute a coroutine in the managed thread.
773 auto coroHdl = adaptor.handle();
774 rewriter.replaceOpWithNewOp<func::CallOp>(
775 op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
776
777 return success();
778 }
779 };
780 } // namespace
781
782 //===----------------------------------------------------------------------===//
783 // Convert async.runtime.store to the corresponding runtime API call.
784 //===----------------------------------------------------------------------===//
785
786 namespace {
787 class RuntimeStoreOpLowering : public OpConversionPattern<RuntimeStoreOp> {
788 public:
789 using OpConversionPattern::OpConversionPattern;
790
791 LogicalResult
matchAndRewrite(RuntimeStoreOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const792 matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
793 ConversionPatternRewriter &rewriter) const override {
794 Location loc = op->getLoc();
795
796 // Get a pointer to the async value storage from the runtime.
797 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
798 auto storage = adaptor.storage();
799 auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
800 TypeRange(i8Ptr), storage);
801
802 // Cast from i8* to the LLVM pointer type.
803 auto valueType = op.value().getType();
804 auto llvmValueType = getTypeConverter()->convertType(valueType);
805 if (!llvmValueType)
806 return rewriter.notifyMatchFailure(
807 op, "failed to convert stored value type to LLVM type");
808
809 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
810 loc, LLVM::LLVMPointerType::get(llvmValueType),
811 storagePtr.getResult(0));
812
813 // Store the yielded value into the async value storage.
814 auto value = adaptor.value();
815 rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr.getResult());
816
817 // Erase the original runtime store operation.
818 rewriter.eraseOp(op);
819
820 return success();
821 }
822 };
823 } // namespace
824
825 //===----------------------------------------------------------------------===//
826 // Convert async.runtime.load to the corresponding runtime API call.
827 //===----------------------------------------------------------------------===//
828
829 namespace {
830 class RuntimeLoadOpLowering : public OpConversionPattern<RuntimeLoadOp> {
831 public:
832 using OpConversionPattern::OpConversionPattern;
833
834 LogicalResult
matchAndRewrite(RuntimeLoadOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const835 matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
836 ConversionPatternRewriter &rewriter) const override {
837 Location loc = op->getLoc();
838
839 // Get a pointer to the async value storage from the runtime.
840 auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
841 auto storage = adaptor.storage();
842 auto storagePtr = rewriter.create<func::CallOp>(loc, kGetValueStorage,
843 TypeRange(i8Ptr), storage);
844
845 // Cast from i8* to the LLVM pointer type.
846 auto valueType = op.result().getType();
847 auto llvmValueType = getTypeConverter()->convertType(valueType);
848 if (!llvmValueType)
849 return rewriter.notifyMatchFailure(
850 op, "failed to convert loaded value type to LLVM type");
851
852 auto castedStoragePtr = rewriter.create<LLVM::BitcastOp>(
853 loc, LLVM::LLVMPointerType::get(llvmValueType),
854 storagePtr.getResult(0));
855
856 // Load from the casted pointer.
857 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, castedStoragePtr.getResult());
858
859 return success();
860 }
861 };
862 } // namespace
863
864 //===----------------------------------------------------------------------===//
865 // Convert async.runtime.add_to_group to the corresponding runtime API call.
866 //===----------------------------------------------------------------------===//
867
868 namespace {
869 class RuntimeAddToGroupOpLowering
870 : public OpConversionPattern<RuntimeAddToGroupOp> {
871 public:
872 using OpConversionPattern::OpConversionPattern;
873
874 LogicalResult
matchAndRewrite(RuntimeAddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const875 matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
876 ConversionPatternRewriter &rewriter) const override {
877 // Currently we can only add tokens to the group.
878 if (!op.operand().getType().isa<TokenType>())
879 return rewriter.notifyMatchFailure(op, "only token type is supported");
880
881 // Replace with a runtime API function call.
882 rewriter.replaceOpWithNewOp<func::CallOp>(
883 op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
884
885 return success();
886 }
887 };
888 } // namespace
889
890 //===----------------------------------------------------------------------===//
891 // Convert async.runtime.num_worker_threads to the corresponding runtime API
892 // call.
893 //===----------------------------------------------------------------------===//
894
895 namespace {
896 class RuntimeNumWorkerThreadsOpLowering
897 : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
898 public:
899 using OpConversionPattern::OpConversionPattern;
900
901 LogicalResult
matchAndRewrite(RuntimeNumWorkerThreadsOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const902 matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
903 ConversionPatternRewriter &rewriter) const override {
904
905 // Replace with a runtime API function call.
906 rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
907 rewriter.getIndexType());
908
909 return success();
910 }
911 };
912 } // namespace
913
914 //===----------------------------------------------------------------------===//
915 // Async reference counting ops lowering (`async.runtime.add_ref` and
916 // `async.runtime.drop_ref` to the corresponding API calls).
917 //===----------------------------------------------------------------------===//
918
919 namespace {
920 template <typename RefCountingOp>
921 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
922 public:
RefCountingOpLowering(TypeConverter & converter,MLIRContext * ctx,StringRef apiFunctionName)923 explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx,
924 StringRef apiFunctionName)
925 : OpConversionPattern<RefCountingOp>(converter, ctx),
926 apiFunctionName(apiFunctionName) {}
927
928 LogicalResult
matchAndRewrite(RefCountingOp op,typename RefCountingOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const929 matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
930 ConversionPatternRewriter &rewriter) const override {
931 auto count = rewriter.create<arith::ConstantOp>(
932 op->getLoc(), rewriter.getI64Type(),
933 rewriter.getI64IntegerAttr(op.count()));
934
935 auto operand = adaptor.operand();
936 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
937 ValueRange({operand, count}));
938
939 return success();
940 }
941
942 private:
943 StringRef apiFunctionName;
944 };
945
946 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
947 public:
RuntimeAddRefOpLowering(TypeConverter & converter,MLIRContext * ctx)948 explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
949 : RefCountingOpLowering(converter, ctx, kAddRef) {}
950 };
951
952 class RuntimeDropRefOpLowering
953 : public RefCountingOpLowering<RuntimeDropRefOp> {
954 public:
RuntimeDropRefOpLowering(TypeConverter & converter,MLIRContext * ctx)955 explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx)
956 : RefCountingOpLowering(converter, ctx, kDropRef) {}
957 };
958 } // namespace
959
960 //===----------------------------------------------------------------------===//
961 // Convert return operations that return async values from async regions.
962 //===----------------------------------------------------------------------===//
963
964 namespace {
965 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
966 public:
967 using OpConversionPattern::OpConversionPattern;
968
969 LogicalResult
matchAndRewrite(func::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const970 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
971 ConversionPatternRewriter &rewriter) const override {
972 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
973 return success();
974 }
975 };
976 } // namespace
977
978 //===----------------------------------------------------------------------===//
979
980 namespace {
981 struct ConvertAsyncToLLVMPass
982 : public ConvertAsyncToLLVMBase<ConvertAsyncToLLVMPass> {
983 void runOnOperation() override;
984 };
985 } // namespace
986
runOnOperation()987 void ConvertAsyncToLLVMPass::runOnOperation() {
988 ModuleOp module = getOperation();
989 MLIRContext *ctx = module->getContext();
990
991 // Add declarations for most functions required by the coroutines lowering.
992 // We delay adding the resume function until it's needed because it currently
993 // fails to compile unless '-O0' is specified.
994 addAsyncRuntimeApiDeclarations(module);
995
996 // Lower async.runtime and async.coro operations to Async Runtime API and
997 // LLVM coroutine intrinsics.
998
999 // Convert async dialect types and operations to LLVM dialect.
1000 AsyncRuntimeTypeConverter converter;
1001 RewritePatternSet patterns(ctx);
1002
1003 // We use conversion to LLVM type to lower async.runtime load and store
1004 // operations.
1005 LLVMTypeConverter llvmConverter(ctx);
1006 llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
1007
1008 // Convert async types in function signatures and function calls.
1009 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1010 converter);
1011 populateCallOpTypeConversionPattern(patterns, converter);
1012
1013 // Convert return operations inside async.execute regions.
1014 patterns.add<ReturnOpOpConversion>(converter, ctx);
1015
1016 // Lower async.runtime operations to the async runtime API calls.
1017 patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1018 RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1019 RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1020 RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1021 RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1022 ctx);
1023
1024 // Lower async.runtime operations that rely on LLVM type converter to convert
1025 // from async value payload type to the LLVM type.
1026 patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1027 RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter,
1028 ctx);
1029
1030 // Lower async coroutine operations to LLVM coroutine intrinsics.
1031 patterns
1032 .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1033 CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1034 converter, ctx);
1035
1036 ConversionTarget target(*ctx);
1037 target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1038 UnrealizedConversionCastOp>();
1039 target.addLegalDialect<LLVM::LLVMDialect>();
1040
1041 // All operations from Async dialect must be lowered to the runtime API and
1042 // LLVM intrinsics calls.
1043 target.addIllegalDialect<AsyncDialect>();
1044
1045 // Add dynamic legality constraints to apply conversions defined above.
1046 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1047 return converter.isSignatureLegal(op.getFunctionType());
1048 });
1049 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1050 return converter.isLegal(op.getOperandTypes());
1051 });
1052 target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1053 return converter.isSignatureLegal(op.getCalleeType());
1054 });
1055
1056 if (failed(applyPartialConversion(module, target, std::move(patterns))))
1057 signalPassFailure();
1058 }
1059
1060 //===----------------------------------------------------------------------===//
1061 // Patterns for structural type conversions for the Async dialect operations.
1062 //===----------------------------------------------------------------------===//
1063
1064 namespace {
1065 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1066 public:
1067 using OpConversionPattern::OpConversionPattern;
1068 LogicalResult
matchAndRewrite(ExecuteOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1069 matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 ExecuteOp newOp =
1072 cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1073 rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1074 newOp.getRegion().end());
1075
1076 // Set operands and update block argument and result types.
1077 newOp->setOperands(adaptor.getOperands());
1078 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1079 return failure();
1080 for (auto result : newOp.getResults())
1081 result.setType(typeConverter->convertType(result.getType()));
1082
1083 rewriter.replaceOp(op, newOp.getResults());
1084 return success();
1085 }
1086 };
1087
1088 // Dummy pattern to trigger the appropriate type conversion / materialization.
1089 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1090 public:
1091 using OpConversionPattern::OpConversionPattern;
1092 LogicalResult
matchAndRewrite(AwaitOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1093 matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1094 ConversionPatternRewriter &rewriter) const override {
1095 rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1096 return success();
1097 }
1098 };
1099
1100 // Dummy pattern to trigger the appropriate type conversion / materialization.
1101 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1102 public:
1103 using OpConversionPattern::OpConversionPattern;
1104 LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1105 matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1106 ConversionPatternRewriter &rewriter) const override {
1107 rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1108 return success();
1109 }
1110 };
1111 } // namespace
1112
createConvertAsyncToLLVMPass()1113 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
1114 return std::make_unique<ConvertAsyncToLLVMPass>();
1115 }
1116
populateAsyncStructuralTypeConversionsAndLegality(TypeConverter & typeConverter,RewritePatternSet & patterns,ConversionTarget & target)1117 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1118 TypeConverter &typeConverter, RewritePatternSet &patterns,
1119 ConversionTarget &target) {
1120 typeConverter.addConversion([&](TokenType type) { return type; });
1121 typeConverter.addConversion([&](ValueType type) {
1122 Type converted = typeConverter.convertType(type.getValueType());
1123 return converted ? ValueType::get(converted) : converted;
1124 });
1125
1126 patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1127 typeConverter, patterns.getContext());
1128
1129 target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1130 [&](Operation *op) { return typeConverter.isLegal(op); });
1131 }
1132