1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/GPU/GPUDialect.h"
16 #include "mlir/Dialect/GPU/Passes.h"
17 #include "mlir/Dialect/GPU/Utils.h"
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 
26 using namespace mlir;
27 namespace {
28 class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
29   struct Callback;
30   void runOnFunction() override;
31 };
32 } // namespace
33 
34 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
35 // execute asynchronously.
36 struct GpuAsyncRegionPass::Callback {
37   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
38   // create a current token (unless it already exists), and 'thread' that token
39   // through the `op` so that it executes asynchronously.
40   //
41   // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
42   // host-synchronize execution.
43   WalkResult operator()(Operation *op) {
44     if (isa<gpu::LaunchOp>(op))
45       return op->emitOpError("replace with gpu.launch_func first");
46     if (isa<gpu::WaitOp>(op))
47       return op->emitOpError("unexpected pre-existing gpu.wait");
48     builder.setInsertionPoint(op);
49     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
50       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
51     if (!currentToken)
52       return success();
53     if (!op->hasTrait<OpTrait::IsTerminator>() &&
54         MemoryEffectOpInterface::hasNoEffect(op))
55       return success();
56     // Insert host synchronization before terminator or op with side effects.
57     currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
58     return success();
59   }
60 
61   // Replaces asyncOp with a clone that returns a token.
62   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
63     auto *op = asyncOp.getOperation();
64     if (asyncOp.getAsyncToken())
65       // TODO: Support ops that are already async.
66       return op->emitOpError("is already async");
67     if (op->getNumRegions() > 0)
68       return op->emitOpError("regions are not supported");
69 
70     // If there is no current token, insert a `gpu.wait async` without
71     // dependencies to create one.
72     if (!currentToken)
73       currentToken = createWaitOp(op->getLoc(), tokenType, {});
74     asyncOp.addAsyncDependency(currentToken);
75 
76     // Clone the op to return a token in addition to the other results.
77     SmallVector<Type, 1> resultTypes = {tokenType};
78     resultTypes.reserve(1 + op->getNumResults());
79     copy(op->getResultTypes(), std::back_inserter(resultTypes));
80     auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
81                                     op->getOperands(), op->getMutableAttrDict(),
82                                     op->getSuccessors());
83 
84     // Replace the op with the async clone.
85     auto results = newOp->getResults();
86     currentToken = results.front();
87     builder.insert(newOp);
88     op->replaceAllUsesWith(results.drop_front());
89     op->erase();
90 
91     return success();
92   }
93 
94   Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
95     return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
96   }
97 
98   OpBuilder builder;
99   const Type tokenType = builder.getType<gpu::AsyncTokenType>();
100   // The token that represents the current asynchronous dependency. It's valid
101   // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
102   // In between, each gpu::AsyncOpInterface depends on the current token and
103   // produces the new one.
104   Value currentToken = {};
105 };
106 
107 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
108 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
109 // execution semantics and that no GPU ops are asynchronous yet.
110 void GpuAsyncRegionPass::runOnFunction() {
111   Callback callback{OpBuilder(&getContext())};
112   if (getFunction().getRegion().walk(callback).wasInterrupted())
113     return signalPassFailure();
114 }
115 
116 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() {
117   return std::make_unique<GpuAsyncRegionPass>();
118 }
119