1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the GPU dialect pattern rewriters that make GPU op 10 // within a region execute asynchronously. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/GPU/GPUDialect.h" 16 #include "mlir/Dialect/GPU/Passes.h" 17 #include "mlir/Dialect/GPU/Utils.h" 18 #include "mlir/Dialect/StandardOps/IR/Ops.h" 19 #include "mlir/IR/BlockAndValueMapping.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/IR/SymbolTable.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Transforms/RegionUtils.h" 25 26 using namespace mlir; 27 namespace { 28 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> { 29 struct Callback; 30 void runOnFunction() override; 31 }; 32 } // namespace 33 34 // Region walk callback which makes GPU ops implementing the AsyncOpInterface 35 // execute asynchronously. 36 struct GpuAsyncRegionPass::Callback { 37 // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to 38 // create a current token (unless it already exists), and 'thread' that token 39 // through the `op` so that it executes asynchronously. 40 // 41 // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to 42 // host-synchronize execution. 43 WalkResult operator()(Operation *op) { 44 if (isa<gpu::LaunchOp>(op)) 45 return op->emitOpError("replace with gpu.launch_func first"); 46 if (isa<gpu::WaitOp>(op)) 47 return op->emitOpError("unexpected pre-existing gpu.wait"); 48 builder.setInsertionPoint(op); 49 if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op)) 50 return rewriteAsyncOp(asyncOp); // Replace GPU op with async version. 51 if (!currentToken) 52 return success(); 53 if (!op->hasTrait<OpTrait::IsTerminator>() && 54 MemoryEffectOpInterface::hasNoEffect(op)) 55 return success(); 56 // Insert host synchronization before terminator or op with side effects. 57 currentToken = createWaitOp(op->getLoc(), Type(), {currentToken}); 58 return success(); 59 } 60 61 // Replaces asyncOp with a clone that returns a token. 62 LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) { 63 auto *op = asyncOp.getOperation(); 64 if (asyncOp.getAsyncToken()) 65 // TODO: Support ops that are already async. 66 return op->emitOpError("is already async"); 67 if (op->getNumRegions() > 0) 68 return op->emitOpError("regions are not supported"); 69 70 // If there is no current token, insert a `gpu.wait async` without 71 // dependencies to create one. 72 if (!currentToken) 73 currentToken = createWaitOp(op->getLoc(), tokenType, {}); 74 asyncOp.addAsyncDependency(currentToken); 75 76 // Clone the op to return a token in addition to the other results. 77 SmallVector<Type, 1> resultTypes = {tokenType}; 78 resultTypes.reserve(1 + op->getNumResults()); 79 copy(op->getResultTypes(), std::back_inserter(resultTypes)); 80 auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, 81 op->getOperands(), op->getMutableAttrDict(), 82 op->getSuccessors()); 83 84 // Replace the op with the async clone. 85 auto results = newOp->getResults(); 86 currentToken = results.front(); 87 builder.insert(newOp); 88 op->replaceAllUsesWith(results.drop_front()); 89 op->erase(); 90 91 return success(); 92 } 93 94 Value createWaitOp(Location loc, Type resultType, ValueRange operands) { 95 return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken(); 96 } 97 98 OpBuilder builder; 99 const Type tokenType = builder.getType<gpu::AsyncTokenType>(); 100 // The token that represents the current asynchronous dependency. It's valid 101 // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op. 102 // In between, each gpu::AsyncOpInterface depends on the current token and 103 // produces the new one. 104 Value currentToken = {}; 105 }; 106 107 // Replaces synchronous GPU ops in the op's region with asynchronous ones and 108 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential 109 // execution semantics and that no GPU ops are asynchronous yet. 110 void GpuAsyncRegionPass::runOnFunction() { 111 Callback callback{OpBuilder(&getContext())}; 112 if (getFunction().getRegion().walk(callback).wasInterrupted()) 113 return signalPassFailure(); 114 } 115 116 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() { 117 return std::make_unique<GpuAsyncRegionPass>(); 118 } 119