//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "convert-async-to-llvm" using namespace mlir; using namespace mlir::async; // Prefix for functions outlined from `async.execute` op regions. static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; //===----------------------------------------------------------------------===// // Async Runtime C API declaration. 
//===----------------------------------------------------------------------===// static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; static constexpr const char *kGetValueStorage = "mlirAsyncRuntimeGetValueStorage"; static constexpr const char *kAddTokenToGroup = "mlirAsyncRuntimeAddTokenToGroup"; static constexpr const char *kAwaitTokenAndExecute = "mlirAsyncRuntimeAwaitTokenAndExecute"; static constexpr const char *kAwaitValueAndExecute = "mlirAsyncRuntimeAwaitValueAndExecute"; static constexpr const char *kAwaitAllAndExecute = "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; namespace { /// Async Runtime API function types. /// /// Because we can't create API function signature for type parametrized /// async.value type, we use opaque pointers (!llvm.ptr) instead. After /// lowering all async data types become opaque pointers at runtime. struct AsyncAPI { // All async types are lowered to opaque i8* LLVM pointers at runtime. 
static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); } static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { return LLVM::LLVMTokenType::get(ctx); } static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = opaquePointerType(ctx); auto count = IntegerType::get(ctx, 32); return FunctionType::get(ctx, {ref, count}, {}); } static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); } static FunctionType createValueFunctionType(MLIRContext *ctx) { auto i32 = IntegerType::get(ctx, 32); auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {i32}, {value}); } static FunctionType createGroupFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); } static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); auto storage = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {storage}); } static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {}); } static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType awaitValueFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {}); } static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); } static FunctionType executeFunctionType(MLIRContext *ctx) { auto hdl = opaquePointerType(ctx); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {hdl, resume}, {}); } static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 
auto i64 = IntegerType::get(ctx, 64); return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, {i64}); } static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = opaquePointerType(ctx); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); } static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); auto hdl = opaquePointerType(ctx); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {value, hdl, resume}, {}); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = opaquePointerType(ctx); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); } // Auxiliary coroutine resume intrinsic wrapper. static Type resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto i8Ptr = opaquePointerType(ctx); return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); } }; } // namespace /// Adds Async Runtime C API declarations to the module. 
///
/// One declaration is created per `mlirAsyncRuntime*` function name, using
/// the matching `AsyncAPI::*FunctionType` signature. A name that already
/// resolves to a symbol in `module` is left untouched; every newly created
/// declaration is marked private.
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
  // All declarations are created at the module body terminator, using the
  // module location.
  auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
                                                         module.getBody());

  // Declares `name` with signature `type` unless the symbol already exists.
  auto addFuncDecl = [&](StringRef name, FunctionType type) {
    if (module.lookupSymbol(name))
      return;
    builder.create(name, type).setPrivate();
  };

  MLIRContext *ctx = module.getContext();
  addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
  addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
  addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
  addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
  addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
  addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
  addFuncDecl(kAwaitTokenAndExecute,
              AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
  addFuncDecl(kAwaitValueAndExecute,
              AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
  addFuncDecl(kAwaitAllAndExecute,
              AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}

//===----------------------------------------------------------------------===//
// LLVM coroutines intrinsics declarations.
//===----------------------------------------------------------------------===// static constexpr const char *kCoroId = "llvm.coro.id"; static constexpr const char *kCoroSizeI64 = "llvm.coro.size.i64"; static constexpr const char *kCoroBegin = "llvm.coro.begin"; static constexpr const char *kCoroSave = "llvm.coro.save"; static constexpr const char *kCoroSuspend = "llvm.coro.suspend"; static constexpr const char *kCoroEnd = "llvm.coro.end"; static constexpr const char *kCoroFree = "llvm.coro.free"; static constexpr const char *kCoroResume = "llvm.coro.resume"; static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, StringRef name, Type ret, ArrayRef params) { if (module.lookupSymbol(name)) return; Type type = LLVM::LLVMFunctionType::get(ret, params); builder.create(name, type); } /// Adds coroutine intrinsics declarations to the module. static void addCoroutineIntrinsicsDeclarations(ModuleOp module) { using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); ImplicitLocOpBuilder builder(module.getLoc(), module.getBody()->getTerminator()); auto token = LLVMTokenType::get(ctx); auto voidTy = LLVMVoidType::get(ctx); auto i8 = IntegerType::get(ctx, 8); auto i1 = IntegerType::get(ctx, 1); auto i32 = IntegerType::get(ctx, 32); auto i64 = IntegerType::get(ctx, 64); auto i8Ptr = LLVMPointerType::get(i8); addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1}); addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1}); addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// // Add malloc/free declarations to the module. 
//===----------------------------------------------------------------------===// static constexpr const char *kMalloc = "malloc"; static constexpr const char *kFree = "free"; /// Adds malloc/free declarations to the module. static void addCRuntimeDeclarations(ModuleOp module) { using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); ImplicitLocOpBuilder builder(module.getLoc(), module.getBody()->getTerminator()); auto voidTy = LLVMVoidType::get(ctx); auto i64 = IntegerType::get(ctx, 64); auto i8Ptr = LLVMPointerType::get(IntegerType::get(ctx, 8)); addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); } //===----------------------------------------------------------------------===// // Coroutine resume function wrapper. //===----------------------------------------------------------------------===// static constexpr const char *kResume = "__resume"; /// A function that takes a coroutine handle and calls a `llvm.coro.resume` /// intrinsics. We need this function to be able to pass it to the async /// runtime execute API. static void addResumeFunction(ModuleOp module) { MLIRContext *ctx = module.getContext(); OpBuilder moduleBuilder(module.getBody()->getTerminator()); Location loc = module.getLoc(); if (module.lookupSymbol(kResume)) return; auto voidTy = LLVM::LLVMVoidType::get(ctx); auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); auto resumeOp = moduleBuilder.create( loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); blockBuilder.create(TypeRange(), blockBuilder.getSymbolRefAttr(kCoroResume), resumeOp.getArgument(0)); blockBuilder.create(ValueRange()); } //===----------------------------------------------------------------------===// // async.execute op outlining to the coroutine functions. 
//===----------------------------------------------------------------------===// /// Function targeted for coroutine transformation has two additional blocks at /// the end: coroutine cleanup and coroutine suspension. /// /// async.await op lowering additionaly creates a resume block for each /// operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { // Async execute region returns a completion token, and an async value for // each yielded value. // // %token, %result = async.execute -> !async.value { // %0 = constant ... : T // async.yield %0 : T // } Value asyncToken; // token representing completion of the async region llvm::SmallVector returnValues; // returned async values Value coroHandle; // coroutine handle (!async.coro.handle value) Block *cleanup; // coroutine cleanup block Block *suspend; // coroutine suspension block }; } // namespace /// Builds an coroutine template compatible with LLVM coroutines switched-resume /// lowering using `async.runtime.*` and `async.coro.*` operations. /// /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. /// - `cleanup` block cleans up the coroutine state. /// - `suspend block after the @llvm.coro.end() defines what value will be /// returned to the initial caller of a coroutine. Everything before the /// @llvm.coro.end() will be executed at every suspension point. 
/// /// Coroutine structure (only the important bits): /// /// func @async_execute_fn() /// -> (!async.token, !async.value) /// { /// ^entry(): /// %token = : !async.token // create async runtime token /// %value = : !async.value // create async value /// %id = async.coro.id // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle /// br ^cleanup /// /// ^cleanup: /// async.coro.free %hdl // delete the coroutine state /// br ^suspend /// /// ^suspend: /// async.coro.end %hdl // marks the end of a coroutine /// return %token, %value : !async.token, !async.value /// } /// /// The actual code for the async.execute operation body region will be inserted /// before the entry block terminator. /// /// static CoroMachinery setupCoroMachinery(FuncOp func) { assert(func.getBody().empty() && "Function must have empty body"); MLIRContext *ctx = func.getContext(); Block *entryBlock = func.addEntryBlock(); auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); // ------------------------------------------------------------------------ // // Allocate async token/values that we will return from a ramp function. // ------------------------------------------------------------------------ // auto retToken = builder.create(TokenType::get(ctx)).result(); llvm::SmallVector retValues; for (auto resType : func.getCallableResults().drop_front()) retValues.emplace_back(builder.create(resType).result()); // ------------------------------------------------------------------------ // // Initialize coroutine: get coroutine id and coroutine handle. 
// ------------------------------------------------------------------------ // auto coroIdOp = builder.create(CoroIdType::get(ctx)); auto coroHdlOp = builder.create(CoroHandleType::get(ctx), coroIdOp.id()); Block *cleanupBlock = func.addBlock(); Block *suspendBlock = func.addBlock(); // ------------------------------------------------------------------------ // // Coroutine cleanup block: deallocate coroutine frame, free the memory. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(cleanupBlock); builder.create(coroIdOp.id(), coroHdlOp.handle()); // Branch into the suspend block. builder.create(suspendBlock); // ------------------------------------------------------------------------ // // Coroutine suspend block: mark the end of a coroutine and return allocated // async token. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: async.coro.end builder.create(coroHdlOp.handle()); // Return created `async.token` and `async.values` from the suspend block. // This will be the return value of a coroutine ramp function. SmallVector ret{retToken}; ret.insert(ret.end(), retValues.begin(), retValues.end()); builder.create(ret); // Branch from the entry block to the cleanup block to create a valid CFG. builder.setInsertionPointToEnd(entryBlock); builder.create(cleanupBlock); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. CoroMachinery machinery; machinery.asyncToken = retToken; machinery.returnValues = retValues; machinery.coroHandle = coroHdlOp.handle(); machinery.cleanup = cleanupBlock; machinery.suspend = suspendBlock; return machinery; } /// Outline the body region attached to the `async.execute` op into a standalone /// function. /// /// Note that this is not reversible transformation. 
///
/// The outlined function takes, in this order, the dependencies of `execute`,
/// its operands, and the values used inside the body region but defined above
/// it. It is inserted into `symbolTable` before the module body terminator
/// and made private. The `async.execute` op is replaced with a call to the
/// outlined function and erased.
///
/// Returns the outlined function together with its coroutine machinery.
static std::pair outlineExecuteOp(SymbolTable &symbolTable,
                                  ExecuteOp execute) {
  ModuleOp module = execute->getParentOfType();

  MLIRContext *ctx = module.getContext();
  Location loc = execute.getLoc();

  // Collect all outlined function inputs.
  llvm::SetVector functionInputs(execute.dependencies().begin(),
                                 execute.dependencies().end());
  functionInputs.insert(execute.operands().begin(), execute.operands().end());
  getUsedValuesDefinedAbove(execute.body(), functionInputs);

  // Collect types for the outlined function inputs and outputs.
  auto typesRange = llvm::map_range(
      functionInputs, [](Value value) { return value.getType(); });
  SmallVector inputTypes(typesRange.begin(), typesRange.end());
  auto outputTypes = execute.getResultTypes();

  auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
  auto funcAttrs = ArrayRef();

  // TODO: Derive outlined function name from the parent FuncOp (support
  // multiple nested async.execute operations).
  FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
  symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));

  SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);

  // Prepare a function for coroutine lowering by adding entry/cleanup/suspend
  // blocks, adding async.coro operations and setting up control flow.
  CoroMachinery coro = setupCoroMachinery(func);

  // Suspend async function at the end of an entry block, and resume it using
  // Async resume operation (execution will be resumed in a thread managed by
  // the async runtime).
  Block *entryBlock = &func.getBlocks().front();
  auto builder = ImplicitLocOpBuilder::atBlockTerminator(loc, entryBlock);

  // Save the coroutine state: async.coro.save
  auto coroSaveOp =
      builder.create(CoroStateType::get(ctx), coro.coroHandle);

  // Pass coroutine to the runtime to be resumed on a runtime managed thread.
  builder.create(coro.coroHandle);

  // Split the entry block before the terminator (the branch to the cleanup
  // block created by `setupCoroMachinery`). The terminator moves into the new
  // `resume` block; `suspended` keeps everything before it.
  auto *terminatorOp = entryBlock->getTerminator();
  Block *suspended = terminatorOp->getBlock();
  Block *resume = suspended->splitBlock(terminatorOp);

  // Add async.coro.suspend as a suspended block terminator.
  builder.setInsertionPointToEnd(suspended);
  builder.create(coroSaveOp.state(), coro.suspend, resume, coro.cleanup);

  size_t numDependencies = execute.dependencies().size();
  size_t numOperands = execute.operands().size();

  // Await on all dependencies before starting to execute the body region.
  // Dependencies are the leading function arguments.
  builder.setInsertionPointToStart(resume);
  for (size_t i = 0; i < numDependencies; ++i)
    builder.create(func.getArgument(i));

  // Await on all async value operands and unwrap the payload.
  // Operands follow the dependencies in the function argument list.
  SmallVector unwrappedOperands(numOperands);
  for (size_t i = 0; i < numOperands; ++i) {
    Value operand = func.getArgument(numDependencies + i);
    unwrappedOperands[i] = builder.create(loc, operand).result();
  }

  // Map from function inputs defined above the execute op to the function
  // arguments.
  BlockAndValueMapping valueMapping;
  valueMapping.map(functionInputs, func.getArguments());
  // Body region block arguments map to the unwrapped operand payloads.
  valueMapping.map(execute.body().getArguments(), unwrappedOperands);

  // Clone all operations from the execute operation body into the outlined
  // function body.
  for (Operation &op : execute.body().getOps())
    builder.clone(op, valueMapping);

  // Replace the original `async.execute` with a call to outlined function.
  ImplicitLocOpBuilder callBuilder(loc, execute);
  auto callOutlinedFunc = callBuilder.create(
      func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
  execute.replaceAllUsesWith(callOutlinedFunc.getResults());
  execute.erase();

  return {func, coro};
}

//===----------------------------------------------------------------------===//
// Convert Async dialect types to LLVM types.
//===----------------------------------------------------------------------===// namespace { /// AsyncRuntimeTypeConverter only converts types from the Async dialect to /// their runtime type (opaque pointers) and does not convert any other types. class AsyncRuntimeTypeConverter : public TypeConverter { public: AsyncRuntimeTypeConverter() { addConversion([](Type type) { return type; }); addConversion(convertAsyncTypes); } static Optional convertAsyncTypes(Type type) { if (type.isa()) return AsyncAPI::opaquePointerType(type.getContext()); if (type.isa()) return AsyncAPI::tokenType(type.getContext()); if (type.isa()) return AsyncAPI::opaquePointerType(type.getContext()); return llvm::None; } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.id to @llvm.coro.id intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroIdOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CoroIdOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto token = AsyncAPI::tokenType(op->getContext()); auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Constants for initializing coroutine frame. auto constZero = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); auto nullPtr = rewriter.create(loc, i8Ptr); // Get coroutine id: @llvm.coro.id. rewriter.replaceOpWithNewOp( op, token, rewriter.getSymbolRefAttr(kCoroId), ValueRange({constZero, nullPtr, nullPtr, nullPtr})); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.begin to @llvm.coro.begin intrinsic. 
//===----------------------------------------------------------------------===//

namespace {
/// Replaces `async.coro.begin` with three calls: `@llvm.coro.size.i64` to get
/// the frame size, `malloc` to allocate the frame, and `@llvm.coro.begin` to
/// produce the coroutine handle.
class CoroBeginOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoroBeginOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
    auto loc = op->getLoc();

    // Get coroutine frame size: @llvm.coro.size.i64.
    auto coroSize = rewriter.create(
        loc, rewriter.getI64Type(), rewriter.getSymbolRefAttr(kCoroSizeI64),
        ValueRange());

    // Allocate memory for the coroutine frame.
    auto coroAlloc = rewriter.create(
        loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc),
        ValueRange(coroSize.getResult(0)));

    // Begin a coroutine: @llvm.coro.begin.
    // The id comes from the already converted operands.
    auto coroId = CoroBeginOpAdaptor(operands).id();
    rewriter.replaceOpWithNewOp(
        op, i8Ptr, rewriter.getSymbolRefAttr(kCoroBegin),
        ValueRange({coroId, coroAlloc.getResult(0)}));

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.coro.free to @llvm.coro.free intrinsic.
//===----------------------------------------------------------------------===//

namespace {
/// Replaces `async.coro.free` with a call to `@llvm.coro.free` followed by a
/// call to `free` on the pointer it returns.
class CoroFreeOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoroFreeOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext());
    auto loc = op->getLoc();

    // Get a pointer to the coroutine frame memory: @llvm.coro.free.
    // The converted operands are forwarded unchanged.
    auto coroMem = rewriter.create(
        loc, i8Ptr, rewriter.getSymbolRefAttr(kCoroFree), operands);

    // Free the memory.
    rewriter.replaceOpWithNewOp(op, TypeRange(),
                                rewriter.getSymbolRefAttr(kFree),
                                ValueRange(coroMem.getResult(0)));

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.coro.end to @llvm.coro.end intrinsic.
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `async.coro.end` to a call to `@llvm.coro.end` with a constant
/// `false` second argument. The original op is erased; the i1 result of the
/// call is not used.
class CoroEndOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoroEndOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    // We are not in the block that is part of the unwind sequence.
    auto constFalse = rewriter.create(
        op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));

    // Mark the end of a coroutine: @llvm.coro.end.
    auto coroHdl = CoroEndOpAdaptor(operands).handle();
    rewriter.create(op->getLoc(), rewriter.getI1Type(),
                    rewriter.getSymbolRefAttr(kCoroEnd),
                    ValueRange({coroHdl, constFalse}));
    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.coro.save to @llvm.coro.save intrinsic.
//===----------------------------------------------------------------------===//

namespace {
/// Replaces `async.coro.save` with a call to `@llvm.coro.save` that returns
/// an LLVM token; the converted operands are forwarded unchanged.
class CoroSaveOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoroSaveOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Save the coroutine state: @llvm.coro.save
    rewriter.replaceOpWithNewOp(
        op, AsyncAPI::tokenType(op->getContext()),
        rewriter.getSymbolRefAttr(kCoroSave), operands);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
//===----------------------------------------------------------------------===//

namespace {

/// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
/// branch to the appropriate block based on the return code.
///
/// Before:
///
///   ^suspended:
///     "opBefore"(...)
///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
///   ^resume:
///     "op"(...)
///   ^cleanup: ...
///   ^suspend: ...
///
/// After:
///
///   ^suspended:
///     "opBefore"(...)
///     %suspend = llvm.call @llvm.coro.suspend(...)
///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
///   ^resume:
///     "op"(...)
///   ^cleanup: ...
///   ^suspend: ...
///
/// In the generated switch only the 0 and 1 cases are listed explicitly;
/// ^suspend is the default destination.
class CoroSuspendOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CoroSuspendOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto i8 = rewriter.getIntegerType(8);
    auto i32 = rewriter.getI32Type();
    auto loc = op->getLoc();

    // This is not a final suspension point.
    auto constFalse = rewriter.create(
        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));

    // Suspend a coroutine: @llvm.coro.suspend
    auto coroState = CoroSuspendOpAdaptor(operands).state();
    auto coroSuspend = rewriter.create(
        loc, i8, rewriter.getSymbolRefAttr(kCoroSuspend),
        ValueRange({coroState, constFalse}));

    // Cast return code to i32.

    // After a suspension point decide if we should branch into resume, cleanup
    // or suspend block of the coroutine (see @llvm.coro.suspend return code
    // documentation).
    // Case 0 goes to the resume block, case 1 to the cleanup block.
    llvm::SmallVector caseValues = {0, 1};
    llvm::SmallVector caseDest = {op.resumeDest(), op.cleanupDest()};
    rewriter.replaceOpWithNewOp(
        op, rewriter.create(loc, i32, coroSuspend.getResult(0)),
        /*defaultDestination=*/op.suspendDest(),
        /*defaultOperands=*/ValueRange(),
        /*caseValues=*/caseValues,
        /*caseDestinations=*/caseDest,
        /*caseOperands=*/ArrayRef(),
        /*branchWeights=*/ArrayRef());

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.create to the corresponding runtime API call.
//
// To allocate storage for the async values we use getelementptr trick:
// http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
//===----------------------------------------------------------------------===//

namespace {
/// Replaces `async.runtime.create` with a call to one of the runtime create
/// functions, selected by the op's result type. Value creation additionally
/// passes the storage size of the converted payload type as an i32. Fails to
/// match for any other result type.
///
/// NOTE(review): the `isa`/`dyn_cast` checks carry no type arguments in this
/// view; the token/group/value mapping is taken from the surrounding
/// comments — confirm.
class RuntimeCreateOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeCreateOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    TypeConverter *converter = getTypeConverter();
    Type resultType = op->getResultTypes()[0];

    // Tokens and Groups lowered to function calls without arguments.
    if (resultType.isa() || resultType.isa()) {
      rewriter.replaceOpWithNewOp(
          op, resultType.isa() ? kCreateToken : kCreateGroup,
          converter->convertType(resultType));
      return success();
    }

    // To create a value we need to compute the storage requirement.
    if (auto value = resultType.dyn_cast()) {
      // Returns the size requirements for the async value storage.
      auto sizeOf = [&](ValueType valueType) -> Value {
        auto loc = op->getLoc();
        auto i32 = rewriter.getI32Type();

        auto storedType = converter->convertType(valueType.getValueType());
        auto storagePtrType = LLVM::LLVMPointerType::get(storedType);

        // %Size = getelementptr %T* null, int 1
        // %SizeI = ptrtoint %T* %Size to i32
        auto nullPtr = rewriter.create(loc, storagePtrType);
        auto one = rewriter.create(
            loc, i32, rewriter.getI32IntegerAttr(1));
        auto gep = rewriter.create(loc, storagePtrType, nullPtr,
                                   one.getResult());
        return rewriter.create(loc, i32, gep);
      };

      // Unlike the token/group branch above, the call is created with the
      // original (unconverted) `resultType`.
      rewriter.replaceOpWithNewOp(op, kCreateValue, resultType,
                                  sizeOf(value));

      return success();
    }

    return rewriter.notifyMatchFailure(op, "unsupported async type");
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.set_available to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `async.runtime.set_available` to a call to the token or value
/// emplace function, selected by the type of the original operand. The op has
/// no results to replace, so it is erased after the call is created. Fails to
/// match for any other operand type.
class RuntimeSetAvailableOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Dispatch on the type of the original operand; the call itself takes
    // the converted `operands`.
    Type operandType = op.operand().getType();

    if (operandType.isa() || operandType.isa()) {
      rewriter.create(op->getLoc(),
                      operandType.isa() ? kEmplaceToken : kEmplaceValue,
                      TypeRange(), operands);
      rewriter.eraseOp(op);
      return success();
    }

    return rewriter.notifyMatchFailure(op, "unsupported async type");
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.await to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `async.runtime.await` to a call to one of the runtime await
/// functions (token, value or group), selected by the type of the original
/// operand. The op is erased after the call is created. Fails to match for
/// any other operand type.
///
/// NOTE(review): the `isa` checks carry no type arguments in this view; the
/// type-to-function mapping is inferred from the API function names —
/// confirm.
class RuntimeAwaitOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeAwaitOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = op.operand().getType();

    StringRef apiFuncName;
    if (operandType.isa())
      apiFuncName = kAwaitToken;
    else if (operandType.isa())
      apiFuncName = kAwaitValue;
    else if (operandType.isa())
      apiFuncName = kAwaitGroup;
    else
      return rewriter.notifyMatchFailure(op, "unsupported async type");

    // The call takes the converted operands and produces no results.
    rewriter.create(op->getLoc(), apiFuncName, TypeRange(), operands);
    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.await_and_resume to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `async.runtime.await_and_resume` to a call to one of the runtime
/// await-and-execute functions, selected by the type of the original operand.
/// The call receives the awaited operand, the coroutine handle and a pointer
/// to the `__resume` wrapper. The op is erased after the call is created.
/// Fails to match for any other operand type.
class RuntimeAwaitAndResumeOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = op.operand().getType();

    StringRef apiFuncName;
    if (operandType.isa())
      apiFuncName = kAwaitTokenAndExecute;
    else if (operandType.isa())
      apiFuncName = kAwaitValueAndExecute;
    else if (operandType.isa())
      apiFuncName = kAwaitAllAndExecute;
    else
      return rewriter.notifyMatchFailure(op, "unsupported async type");

    // Converted operand and handle values.
    Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand();
    Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle();

    // A pointer to coroutine resume intrinsic wrapper.
    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
    auto resumePtr = rewriter.create(
        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);

    rewriter.create(op->getLoc(), apiFuncName, TypeRange(),
                    ValueRange({operand, handle, resumePtr.res()}));
    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.resume to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
/// Replaces `async.runtime.resume` with a call to the runtime execute
/// function, passing the coroutine handle and a pointer to the `__resume`
/// wrapper.
class RuntimeResumeOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeResumeOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    // A pointer to coroutine resume intrinsic wrapper.
    auto resumeFnTy = AsyncAPI::resumeFunctionType(op->getContext());
    auto resumePtr = rewriter.create(
        op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume);

    // Call async runtime API to execute a coroutine in the managed thread.
    auto coroHdl = RuntimeResumeOpAdaptor(operands).handle();
    rewriter.replaceOpWithNewOp(op, TypeRange(), kExecute,
                                ValueRange({coroHdl, resumePtr.res()}));

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.store to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
/// Lowers `async.runtime.store`: asks the runtime for the storage pointer of
/// the async value, casts it from i8* to a pointer to the converted value
/// type, and stores the value through it. The original op is erased.
class RuntimeStoreOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeStoreOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    // Get a pointer to the async value storage from the runtime.
    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
    auto storage = RuntimeStoreOpAdaptor(operands).storage();
    auto storagePtr = rewriter.create(loc, kGetValueStorage,
                                      TypeRange(i8Ptr), storage);

    // Cast from i8* to the LLVM pointer type.
    // The pointee type is the converted type of the original stored value.
    auto valueType = op.value().getType();
    auto llvmValueType = getTypeConverter()->convertType(valueType);
    auto castedStoragePtr = rewriter.create(
        loc, LLVM::LLVMPointerType::get(llvmValueType),
        storagePtr.getResult(0));

    // Store the yielded value into the async value storage.
    auto value = RuntimeStoreOpAdaptor(operands).value();
    rewriter.create(loc, value, castedStoragePtr.getResult());

    // Erase the original runtime store operation.
    rewriter.eraseOp(op);

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Convert async.runtime.load to the corresponding runtime API call.
//===----------------------------------------------------------------------===//

namespace {
class RuntimeLoadOpLowering : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(RuntimeLoadOp op, ArrayRef operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    // Get a pointer to the async value storage from the runtime.
    auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext());
    auto storage = RuntimeLoadOpAdaptor(operands).storage();
    auto storagePtr = rewriter.create(loc, kGetValueStorage,
                                      TypeRange(i8Ptr), storage);

    // Cast from i8* to the LLVM pointer type.
    auto valueType = op.result().getType();
    auto llvmValueType = getTypeConverter()->convertType(valueType);
    auto castedStoragePtr = rewriter.create(
        loc, LLVM::LLVMPointerType::get(llvmValueType),
        storagePtr.getResult(0));

    // Load from the casted pointer.
rewriter.replaceOpWithNewOp(op, castedStoragePtr.getResult()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.add_to_group to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeAddToGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. if (!op.operand().getType().isa()) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, rewriter.getI64Type(), operands); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Async reference counting ops lowering (`async.runtime.add_ref` and // `async.runtime.drop_ref` to the corresponding API calls). 
//===----------------------------------------------------------------------===// namespace { template class RefCountingOpLowering : public OpConversionPattern { public: explicit RefCountingOpLowering(TypeConverter &converter, MLIRContext *ctx, StringRef apiFunctionName) : OpConversionPattern(converter, ctx), apiFunctionName(apiFunctionName) {} LogicalResult matchAndRewrite(RefCountingOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto count = rewriter.create(op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.count())); auto operand = typename RefCountingOp::Adaptor(operands).operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, ValueRange({operand, count})); return success(); } private: StringRef apiFunctionName; }; class RuntimeAddRefOpLowering : public RefCountingOpLowering { public: explicit RuntimeAddRefOpLowering(TypeConverter &converter, MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kAddRef) {} }; class RuntimeDropRefOpLowering : public RefCountingOpLowering { public: explicit RuntimeDropRefOpLowering(TypeConverter &converter, MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kDropRef) {} }; } // namespace //===----------------------------------------------------------------------===// // Convert return operations that return async values from async regions. 
//===----------------------------------------------------------------------===// namespace { class ReturnOpOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.create_group operation to async.runtime.create //===----------------------------------------------------------------------===// namespace { class CreateGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CreateGroupOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, GroupType::get(op->getContext())); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.add_to_group operation to async.runtime.add_to_group. //===----------------------------------------------------------------------===// namespace { class AddToGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AddToGroupOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, rewriter.getIndexType(), operands); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.await and async.await_all operations to the async.runtime.await // or async.runtime.await_and_resume operations. 
//===----------------------------------------------------------------------===// namespace { template class AwaitOpLoweringBase : public OpConversionPattern { using AwaitAdaptor = typename AwaitType::Adaptor; public: AwaitOpLoweringBase( MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) : OpConversionPattern(ctx), outlinedFunctions(outlinedFunctions) {} LogicalResult matchAndRewrite(AwaitType op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). if (!op.operand().getType().template isa()) return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); // Check if await operation is inside the outlined coroutine function. auto func = op->template getParentOfType(); auto outlined = outlinedFunctions.find(func); const bool isInCoroutine = outlined != outlinedFunctions.end(); Location loc = op->getLoc(); Value operand = AwaitAdaptor(operands).operand(); // Inside regular functions we use the blocking wait operation to wait for // the async object (token, value or group) to become available. if (!isInCoroutine) rewriter.create(loc, operand); // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. if (isInCoroutine) { const CoroMachinery &coro = outlined->getSecond(); Block *suspended = op->getBlock(); ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); MLIRContext *ctx = op->getContext(); // Save the coroutine state and resume on a runtime managed thread when // the operand becomes available. auto coroSaveOp = builder.create(CoroStateType::get(ctx), coro.coroHandle); builder.create(operand, coro.coroHandle); // Split the entry block before the await operation. Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); // Add async.coro.suspend as a suspended block terminator. 
builder.setInsertionPointToEnd(suspended); builder.create(coroSaveOp.state(), coro.suspend, resume, coro.cleanup); // Make sure that replacement value will be constructed in resume block. rewriter.setInsertionPointToStart(resume); } // Erase or replace the await operation with the new value. if (Value replaceWith = getReplacementValue(op, operand, rewriter)) rewriter.replaceOp(op, replaceWith); else rewriter.eraseOp(op); return success(); } virtual Value getReplacementValue(AwaitType op, Value operand, ConversionPatternRewriter &rewriter) const { return Value(); } private: const llvm::DenseMap &outlinedFunctions; }; /// Lowering for `async.await` with a token operand. class AwaitTokenOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; }; /// Lowering for `async.await` with a value operand. class AwaitValueOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; Value getReplacementValue(AwaitOp op, Value operand, ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. auto valueType = operand.getType().cast().getValueType(); return rewriter.create(op->getLoc(), valueType, operand); } }; /// Lowering for `async.await_all` operation. class AwaitAllOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; }; } // namespace //===----------------------------------------------------------------------===// // Convert async.yield operation to async.runtime operations. 
//===----------------------------------------------------------------------===// class YieldOpLowering : public OpConversionPattern { public: YieldOpLowering( MLIRContext *ctx, const llvm::DenseMap &outlinedFunctions) : OpConversionPattern(ctx), outlinedFunctions(outlinedFunctions) {} LogicalResult matchAndRewrite(async::YieldOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Check if yield operation is inside the outlined coroutine function. auto func = op->template getParentOfType(); auto outlined = outlinedFunctions.find(func); if (outlined == outlinedFunctions.end()) return rewriter.notifyMatchFailure( op, "operation is not inside the outlined async.execute function"); Location loc = op->getLoc(); const CoroMachinery &coro = outlined->getSecond(); // Store yielded values into the async values storage and switch async // values state to available. for (auto tuple : llvm::zip(operands, coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); rewriter.create(loc, yieldValue, asyncValue); rewriter.create(loc, asyncValue); } // Switch the coroutine completion token to available state. rewriter.replaceOpWithNewOp(op, coro.asyncToken); return success(); } private: const llvm::DenseMap &outlinedFunctions; }; //===----------------------------------------------------------------------===// namespace { struct ConvertAsyncToLLVMPass : public ConvertAsyncToLLVMBase { void runOnOperation() override; }; } // namespace void ConvertAsyncToLLVMPass::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symbolTable(module); MLIRContext *ctx = &getContext(); // Outline all `async.execute` body regions into async functions (coroutines). llvm::DenseMap outlinedFunctions; // We use conversion to LLVM type to ensure that all `async.value` operands // and results can be lowered to LLVM load and store operations. 
LLVMTypeConverter llvmConverter(ctx); llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); // Returns true if the `async.value` payload is convertible to LLVM. auto isConvertibleToLlvm = [&](Type type) -> bool { auto valueType = type.cast().getValueType(); return static_cast(llvmConverter.convertType(valueType)); }; WalkResult outlineResult = module.walk([&](ExecuteOp execute) { // All operands and results must be convertible to LLVM. if (!llvm::all_of(execute.operands().getTypes(), isConvertibleToLlvm)) { execute.emitOpError("operands payload must be convertible to LLVM type"); return WalkResult::interrupt(); } if (!llvm::all_of(execute.results().getTypes(), isConvertibleToLlvm)) { execute.emitOpError("results payload must be convertible to LLVM type"); return WalkResult::interrupt(); } outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); return WalkResult::advance(); }); // Failed to outline all async execute operations. if (outlineResult.wasInterrupted()) { signalPassFailure(); return; } LLVM_DEBUG({ llvm::dbgs() << "Outlined " << outlinedFunctions.size() << " async functions\n"; }); // Add declarations for all functions required by the coroutines lowering. addResumeFunction(module); addAsyncRuntimeApiDeclarations(module); addCoroutineIntrinsicsDeclarations(module); addCRuntimeDeclarations(module); // ------------------------------------------------------------------------ // // Lower async operations to async.runtime operations. // ------------------------------------------------------------------------ // OwningRewritePatternList asyncPatterns; // Async lowering does not use type converter because it must preserve all // types for async.runtime operations. asyncPatterns.insert(ctx); asyncPatterns.insert(ctx, outlinedFunctions); // All high level async operations must be lowered to the runtime operations. 
ConversionTarget runtimeTarget(*ctx); runtimeTarget.addLegalDialect(); runtimeTarget.addIllegalOp(); runtimeTarget.addIllegalOp(); if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { signalPassFailure(); return; } // ------------------------------------------------------------------------ // // Lower async.runtime and async.coro operations to Async Runtime API and // LLVM coroutine intrinsics. // ------------------------------------------------------------------------ // // Convert async dialect types and operations to LLVM dialect. AsyncRuntimeTypeConverter converter; OwningRewritePatternList patterns; // Convert async types in function signatures and function calls. populateFuncOpTypeConversionPattern(patterns, ctx, converter); populateCallOpTypeConversionPattern(patterns, ctx, converter); // Convert return operations inside async.execute regions. patterns.insert(converter, ctx); // Lower async.runtime operations to the async runtime API calls. patterns.insert(converter, ctx); // Lower async.runtime operations that rely on LLVM type converter to convert // from async value payload type to the LLVM type. patterns.insert(llvmConverter, ctx); // Lower async coroutine operations to LLVM coroutine intrinsics. patterns.insert(converter, ctx); ConversionTarget target(*ctx); target.addLegalOp(); target.addLegalDialect(); // All operations from Async dialect must be lowered to the runtime API and // LLVM intrinsics calls. target.addIllegalDialect(); // Add dynamic legality constraints to apply conversions defined above. 
target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); target.addDynamicallyLegalOp( [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); target.addDynamicallyLegalOp([&](CallOp op) { return converter.isSignatureLegal(op.getCalleeType()); }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } //===----------------------------------------------------------------------===// // Patterns for structural type conversions for the Async dialect operations. //===----------------------------------------------------------------------===// namespace { class ConvertExecuteOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ExecuteOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ExecuteOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), newOp.getRegion().end()); // Set operands and update block argument and result types. newOp->setOperands(operands); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) return failure(); for (auto result : newOp.getResults()) result.setType(typeConverter->convertType(result.getType())); rewriter.replaceOp(op, newOp.getResults()); return success(); } }; // Dummy pattern to trigger the appropriate type conversion / materialization. class ConvertAwaitOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AwaitOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands.front()); return success(); } }; // Dummy pattern to trigger the appropriate type conversion / materialization. 
class ConvertYieldOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(async::YieldOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); return success(); } }; } // namespace std::unique_ptr> mlir::createConvertAsyncToLLVMPass() { return std::make_unique(); } void mlir::populateAsyncStructuralTypeConversionsAndLegality( MLIRContext *context, TypeConverter &typeConverter, OwningRewritePatternList &patterns, ConversionTarget &target) { typeConverter.addConversion([&](TokenType type) { return type; }); typeConverter.addConversion([&](ValueType type) { return ValueType::get(typeConverter.convertType(type.getValueType())); }); patterns .insert( typeConverter, context); target.addDynamicallyLegalOp( [&](Operation *op) { return typeConverter.isLegal(op); }); }