1db7129a0SChristian Sigg //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2db7129a0SChristian Sigg //
3db7129a0SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4db7129a0SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5db7129a0SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6db7129a0SChristian Sigg //
7db7129a0SChristian Sigg //===----------------------------------------------------------------------===//
8db7129a0SChristian Sigg //
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11db7129a0SChristian Sigg //
12db7129a0SChristian Sigg //===----------------------------------------------------------------------===//
13db7129a0SChristian Sigg
14db7129a0SChristian Sigg #include "PassDetail.h"
15d9adde5aSChristian Sigg #include "mlir/Dialect/Async/IR/Async.h"
16*d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17*d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
18*d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Utils.h"
19db7129a0SChristian Sigg #include "mlir/IR/BlockAndValueMapping.h"
20db7129a0SChristian Sigg #include "mlir/IR/Builders.h"
21db7129a0SChristian Sigg #include "mlir/IR/PatternMatch.h"
22db7129a0SChristian Sigg #include "mlir/IR/SymbolTable.h"
23db7129a0SChristian Sigg #include "mlir/Support/LLVM.h"
24db7129a0SChristian Sigg #include "mlir/Transforms/RegionUtils.h"
25d9adde5aSChristian Sigg #include "llvm/ADT/TypeSwitch.h"
26db7129a0SChristian Sigg
27db7129a0SChristian Sigg using namespace mlir;
28db7129a0SChristian Sigg namespace {
29db7129a0SChristian Sigg class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
30d9adde5aSChristian Sigg struct ThreadTokenCallback;
31d9adde5aSChristian Sigg struct DeferWaitCallback;
32f03826f8SChristian Sigg struct SingleTokenUseCallback;
3341574554SRiver Riddle void runOnOperation() override;
34db7129a0SChristian Sigg };
35db7129a0SChristian Sigg } // namespace
36db7129a0SChristian Sigg
isTerminator(Operation * op)37fe7c0d90SRiver Riddle static bool isTerminator(Operation *op) {
38fe7c0d90SRiver Riddle return op->mightHaveTrait<OpTrait::IsTerminator>();
39fe7c0d90SRiver Riddle }
hasSideEffects(Operation * op)40d9adde5aSChristian Sigg static bool hasSideEffects(Operation *op) {
41d9adde5aSChristian Sigg return !MemoryEffectOpInterface::hasNoEffect(op);
42d9adde5aSChristian Sigg }
43d9adde5aSChristian Sigg
44db7129a0SChristian Sigg // Region walk callback which makes GPU ops implementing the AsyncOpInterface
45db7129a0SChristian Sigg // execute asynchronously.
46d9adde5aSChristian Sigg struct GpuAsyncRegionPass::ThreadTokenCallback {
ThreadTokenCallbackGpuAsyncRegionPass::ThreadTokenCallback47d9adde5aSChristian Sigg ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
48d9adde5aSChristian Sigg
operator ()GpuAsyncRegionPass::ThreadTokenCallback490b21371eSChristian Sigg WalkResult operator()(Block *block) {
500b21371eSChristian Sigg for (Operation &op : make_early_inc_range(*block)) {
510b21371eSChristian Sigg if (failed(visit(&op)))
520b21371eSChristian Sigg return WalkResult::interrupt();
530b21371eSChristian Sigg }
540b21371eSChristian Sigg return WalkResult::advance();
550b21371eSChristian Sigg }
560b21371eSChristian Sigg
570b21371eSChristian Sigg private:
58db7129a0SChristian Sigg // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
59db7129a0SChristian Sigg // create a current token (unless it already exists), and 'thread' that token
60db7129a0SChristian Sigg // through the `op` so that it executes asynchronously.
61db7129a0SChristian Sigg //
62db7129a0SChristian Sigg // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
63d9adde5aSChristian Sigg // host-synchronize execution. A `!gpu.async.token` will therefore only be
64d9adde5aSChristian Sigg // used inside of its block and GPU execution will always synchronize with
65d9adde5aSChristian Sigg // the host at block boundaries.
visitGpuAsyncRegionPass::ThreadTokenCallback660b21371eSChristian Sigg LogicalResult visit(Operation *op) {
67db7129a0SChristian Sigg if (isa<gpu::LaunchOp>(op))
68db7129a0SChristian Sigg return op->emitOpError("replace with gpu.launch_func first");
690b21371eSChristian Sigg if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
700b21371eSChristian Sigg if (currentToken)
710b21371eSChristian Sigg waitOp.addAsyncDependency(currentToken);
720b21371eSChristian Sigg currentToken = waitOp.asyncToken();
730b21371eSChristian Sigg return success();
740b21371eSChristian Sigg }
75db7129a0SChristian Sigg builder.setInsertionPoint(op);
76db7129a0SChristian Sigg if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
77db7129a0SChristian Sigg return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
78db7129a0SChristian Sigg if (!currentToken)
79db7129a0SChristian Sigg return success();
80db7129a0SChristian Sigg // Insert host synchronization before terminator or op with side effects.
81d9adde5aSChristian Sigg if (isTerminator(op) || hasSideEffects(op))
82db7129a0SChristian Sigg currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
83db7129a0SChristian Sigg return success();
84db7129a0SChristian Sigg }
85db7129a0SChristian Sigg
86db7129a0SChristian Sigg // Replaces asyncOp with a clone that returns a token.
rewriteAsyncOpGpuAsyncRegionPass::ThreadTokenCallback87db7129a0SChristian Sigg LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
88db7129a0SChristian Sigg auto *op = asyncOp.getOperation();
894c372a35SChristian Sigg auto tokenType = builder.getType<gpu::AsyncTokenType>();
904c372a35SChristian Sigg
91db7129a0SChristian Sigg // If there is no current token, insert a `gpu.wait async` without
92db7129a0SChristian Sigg // dependencies to create one.
93db7129a0SChristian Sigg if (!currentToken)
94db7129a0SChristian Sigg currentToken = createWaitOp(op->getLoc(), tokenType, {});
95db7129a0SChristian Sigg asyncOp.addAsyncDependency(currentToken);
96db7129a0SChristian Sigg
970b21371eSChristian Sigg // Return early if op returns a token already.
980b21371eSChristian Sigg currentToken = asyncOp.getAsyncToken();
990b21371eSChristian Sigg if (currentToken)
1000b21371eSChristian Sigg return success();
1010b21371eSChristian Sigg
102db7129a0SChristian Sigg // Clone the op to return a token in addition to the other results.
103a79b26dbSChristian Sigg SmallVector<Type, 1> resultTypes;
104db7129a0SChristian Sigg resultTypes.reserve(1 + op->getNumResults());
105db7129a0SChristian Sigg copy(op->getResultTypes(), std::back_inserter(resultTypes));
106a79b26dbSChristian Sigg resultTypes.push_back(tokenType);
107db7129a0SChristian Sigg auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
108fc5cf50eSRiver Riddle op->getOperands(), op->getAttrDictionary(),
109a0d019fcSChristian Sigg op->getSuccessors(), op->getNumRegions());
110a0d019fcSChristian Sigg
111a0d019fcSChristian Sigg // Clone regions into new op.
112a0d019fcSChristian Sigg BlockAndValueMapping mapping;
113a0d019fcSChristian Sigg for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
114a0d019fcSChristian Sigg std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
115db7129a0SChristian Sigg
116db7129a0SChristian Sigg // Replace the op with the async clone.
117db7129a0SChristian Sigg auto results = newOp->getResults();
118a79b26dbSChristian Sigg currentToken = results.back();
119db7129a0SChristian Sigg builder.insert(newOp);
120a79b26dbSChristian Sigg op->replaceAllUsesWith(results.drop_back());
121db7129a0SChristian Sigg op->erase();
122db7129a0SChristian Sigg
123db7129a0SChristian Sigg return success();
124db7129a0SChristian Sigg }
125db7129a0SChristian Sigg
createWaitOpGpuAsyncRegionPass::ThreadTokenCallback126db7129a0SChristian Sigg Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
127db7129a0SChristian Sigg return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
128db7129a0SChristian Sigg }
129db7129a0SChristian Sigg
130db7129a0SChristian Sigg OpBuilder builder;
1314c372a35SChristian Sigg
132db7129a0SChristian Sigg // The token that represents the current asynchronous dependency. It's valid
133db7129a0SChristian Sigg // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
134db7129a0SChristian Sigg // In between, each gpu::AsyncOpInterface depends on the current token and
135db7129a0SChristian Sigg // produces the new one.
136db7129a0SChristian Sigg Value currentToken = {};
137db7129a0SChristian Sigg };
138db7129a0SChristian Sigg
139f03826f8SChristian Sigg /// Erases `executeOp` and returns a clone with additional `results`.
addExecuteResults(async::ExecuteOp executeOp,ValueRange results)140f03826f8SChristian Sigg async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
141f03826f8SChristian Sigg ValueRange results) {
142f03826f8SChristian Sigg // Add values to async.yield op.
143f03826f8SChristian Sigg Operation *yieldOp = executeOp.getBody()->getTerminator();
144f03826f8SChristian Sigg yieldOp->insertOperands(yieldOp->getNumOperands(), results);
145f03826f8SChristian Sigg
146f03826f8SChristian Sigg // Construct new result type list with additional types.
147f03826f8SChristian Sigg SmallVector<Type, 2> resultTypes;
148f03826f8SChristian Sigg resultTypes.reserve(executeOp.getNumResults() + results.size());
149f03826f8SChristian Sigg transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
150f03826f8SChristian Sigg [](Type type) {
151f03826f8SChristian Sigg // Extract value type from !async.value.
152f03826f8SChristian Sigg if (auto valueType = type.dyn_cast<async::ValueType>())
153f03826f8SChristian Sigg return valueType.getValueType();
154f03826f8SChristian Sigg assert(type.isa<async::TokenType>() && "expected token type");
155f03826f8SChristian Sigg return type;
156f03826f8SChristian Sigg });
157f03826f8SChristian Sigg transform(results, std::back_inserter(resultTypes),
158f03826f8SChristian Sigg [](Value value) { return value.getType(); });
159f03826f8SChristian Sigg
160f03826f8SChristian Sigg // Clone executeOp with the extra results.
161f03826f8SChristian Sigg OpBuilder builder(executeOp);
162f03826f8SChristian Sigg auto newOp = builder.create<async::ExecuteOp>(
163f03826f8SChristian Sigg executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
164f03826f8SChristian Sigg executeOp.dependencies(), executeOp.operands());
165f03826f8SChristian Sigg BlockAndValueMapping mapper;
166f03826f8SChristian Sigg newOp.getRegion().getBlocks().clear();
167f03826f8SChristian Sigg executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
168f03826f8SChristian Sigg
169f03826f8SChristian Sigg // Replace executeOp with cloned one.
170f03826f8SChristian Sigg executeOp.getOperation()->replaceAllUsesWith(
171f03826f8SChristian Sigg newOp.getResults().drop_back(results.size()));
172f03826f8SChristian Sigg executeOp.erase();
173f03826f8SChristian Sigg
174f03826f8SChristian Sigg return newOp;
175f03826f8SChristian Sigg }
176f03826f8SChristian Sigg
177d9adde5aSChristian Sigg // Callback for `async.execute` ops which tries to push the contained
178d9adde5aSChristian Sigg // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
179d9adde5aSChristian Sigg struct GpuAsyncRegionPass::DeferWaitCallback {
180d9adde5aSChristian Sigg // If the `executeOp`s token is used only in `async.execute` or `async.await`
181d9adde5aSChristian Sigg // ops, add the region's last `gpu.wait` op to the worklist if it is
182d9adde5aSChristian Sigg // synchronous and is the last op with side effects.
operator ()GpuAsyncRegionPass::DeferWaitCallback183d9adde5aSChristian Sigg void operator()(async::ExecuteOp executeOp) {
184d9adde5aSChristian Sigg if (!areAllUsersExecuteOrAwait(executeOp.token()))
185d9adde5aSChristian Sigg return;
186d9adde5aSChristian Sigg // async.execute's region is currently restricted to one block.
187d9adde5aSChristian Sigg for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
188d9adde5aSChristian Sigg if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
189d9adde5aSChristian Sigg if (!waitOp.asyncToken())
190d9adde5aSChristian Sigg worklist.push_back(waitOp);
191d9adde5aSChristian Sigg return;
192d9adde5aSChristian Sigg }
193d9adde5aSChristian Sigg if (hasSideEffects(&op))
194d9adde5aSChristian Sigg return;
195d9adde5aSChristian Sigg }
196d9adde5aSChristian Sigg }
197d9adde5aSChristian Sigg
198d9adde5aSChristian Sigg // The destructor performs the actual rewrite work.
~DeferWaitCallbackGpuAsyncRegionPass::DeferWaitCallback199d9adde5aSChristian Sigg ~DeferWaitCallback() {
200d9adde5aSChristian Sigg for (size_t i = 0; i < worklist.size(); ++i) {
201d9adde5aSChristian Sigg auto waitOp = worklist[i];
2020bf4a82aSChristian Sigg auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
203d9adde5aSChristian Sigg
204f03826f8SChristian Sigg // Erase `gpu.wait` and return async dependencies from execute op instead.
205f03826f8SChristian Sigg SmallVector<Value, 4> dependencies = waitOp.asyncDependencies();
206d9adde5aSChristian Sigg waitOp.erase();
207f03826f8SChristian Sigg executeOp = addExecuteResults(executeOp, dependencies);
208d9adde5aSChristian Sigg
209d9adde5aSChristian Sigg // Add the async dependency to each user of the `async.execute` token.
210f03826f8SChristian Sigg auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
2116e1ac68aSVitaly Buka SmallVector<Operation *, 4> users(executeOp.token().user_begin(),
2126e1ac68aSVitaly Buka executeOp.token().user_end());
2136e1ac68aSVitaly Buka for (Operation *user : users)
214d9adde5aSChristian Sigg addAsyncDependencyAfter(asyncTokens, user);
215d9adde5aSChristian Sigg }
216d9adde5aSChristian Sigg }
217d9adde5aSChristian Sigg
218d9adde5aSChristian Sigg private:
219d9adde5aSChristian Sigg // Returns whether all token users are either 'async.execute' or 'async.await'
220d9adde5aSChristian Sigg // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
221d9adde5aSChristian Sigg // 'async.execute' body to it's users. Specifically, we do not allow
222d9adde5aSChristian Sigg // terminator users, because it could mean that the `async.execute` is inside
223d9adde5aSChristian Sigg // control flow code.
areAllUsersExecuteOrAwaitGpuAsyncRegionPass::DeferWaitCallback224d9adde5aSChristian Sigg static bool areAllUsersExecuteOrAwait(Value token) {
225f03826f8SChristian Sigg return !token.use_empty() &&
226f03826f8SChristian Sigg llvm::all_of(token.getUsers(), [](Operation *user) {
227d9adde5aSChristian Sigg return isa<async::ExecuteOp, async::AwaitOp>(user);
228d9adde5aSChristian Sigg });
229d9adde5aSChristian Sigg }
230d9adde5aSChristian Sigg
231d9adde5aSChristian Sigg // Add the `asyncToken` as dependency as needed after `op`.
addAsyncDependencyAfterGpuAsyncRegionPass::DeferWaitCallback232d9adde5aSChristian Sigg void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
233d9adde5aSChristian Sigg OpBuilder builder(op->getContext());
234d9adde5aSChristian Sigg auto loc = op->getLoc();
235d9adde5aSChristian Sigg
236d9adde5aSChristian Sigg Block::iterator it;
237d9adde5aSChristian Sigg SmallVector<Value, 1> tokens;
238d9adde5aSChristian Sigg tokens.reserve(asyncTokens.size());
239d9adde5aSChristian Sigg TypeSwitch<Operation *>(op)
240d9adde5aSChristian Sigg .Case<async::AwaitOp>([&](auto awaitOp) {
241d9adde5aSChristian Sigg // Add async.await ops to wait for the !gpu.async.tokens.
242d9adde5aSChristian Sigg builder.setInsertionPointAfter(op);
243d9adde5aSChristian Sigg for (auto asyncToken : asyncTokens)
244d9adde5aSChristian Sigg tokens.push_back(
245d9adde5aSChristian Sigg builder.create<async::AwaitOp>(loc, asyncToken).result());
246d9adde5aSChristian Sigg // Set `it` after the inserted async.await ops.
247d9adde5aSChristian Sigg it = builder.getInsertionPoint();
248d9adde5aSChristian Sigg })
249d9adde5aSChristian Sigg .Case<async::ExecuteOp>([&](auto executeOp) {
250d9adde5aSChristian Sigg // Set `it` to the beginning of the region and add asyncTokens to the
251d9adde5aSChristian Sigg // async.execute operands.
252d9adde5aSChristian Sigg it = executeOp.getBody()->begin();
253d9adde5aSChristian Sigg executeOp.operandsMutable().append(asyncTokens);
254d9adde5aSChristian Sigg SmallVector<Type, 1> tokenTypes(
255d9adde5aSChristian Sigg asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
256e084679fSRiver Riddle SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
257e084679fSRiver Riddle executeOp.getLoc());
258e084679fSRiver Riddle copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
259d9adde5aSChristian Sigg std::back_inserter(tokens));
260d9adde5aSChristian Sigg });
261d9adde5aSChristian Sigg
262d9adde5aSChristian Sigg // Advance `it` to terminator or op with side-effects.
263d9adde5aSChristian Sigg it = std::find_if(it, Block::iterator(), [](Operation &op) {
264d9adde5aSChristian Sigg return isTerminator(&op) || hasSideEffects(&op);
265d9adde5aSChristian Sigg });
266d9adde5aSChristian Sigg
267d9adde5aSChristian Sigg // If `op` implements the AsyncOpInterface, add `token` to the list of async
268d9adde5aSChristian Sigg // dependencies.
269d9adde5aSChristian Sigg if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
270d9adde5aSChristian Sigg for (auto token : tokens)
271d9adde5aSChristian Sigg asyncOp.addAsyncDependency(token);
272d9adde5aSChristian Sigg return;
273d9adde5aSChristian Sigg }
274d9adde5aSChristian Sigg
275d9adde5aSChristian Sigg // Otherwise, insert a gpu.wait before 'it'.
276d9adde5aSChristian Sigg builder.setInsertionPoint(it->getBlock(), it);
277d9adde5aSChristian Sigg auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
278d9adde5aSChristian Sigg
279d9adde5aSChristian Sigg // If the new waitOp is at the end of an async.execute region, add it to the
280d9adde5aSChristian Sigg // worklist. 'operator()(executeOp)' would do the same, but this is faster.
281d9adde5aSChristian Sigg auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
282d9adde5aSChristian Sigg if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
283d9adde5aSChristian Sigg !it->getNextNode())
284d9adde5aSChristian Sigg worklist.push_back(waitOp);
285d9adde5aSChristian Sigg }
286d9adde5aSChristian Sigg
287d9adde5aSChristian Sigg SmallVector<gpu::WaitOp, 8> worklist;
288d9adde5aSChristian Sigg };
289d9adde5aSChristian Sigg
290f03826f8SChristian Sigg // Callback for `async.execute` ops which repeats !gpu.async.token results
291f03826f8SChristian Sigg // so that each of them is only used once.
292f03826f8SChristian Sigg struct GpuAsyncRegionPass::SingleTokenUseCallback {
operator ()GpuAsyncRegionPass::SingleTokenUseCallback293f03826f8SChristian Sigg void operator()(async::ExecuteOp executeOp) {
294f03826f8SChristian Sigg // Extract !gpu.async.token results which have multiple uses.
295f03826f8SChristian Sigg auto multiUseResults =
296f03826f8SChristian Sigg llvm::make_filter_range(executeOp.results(), [](OpResult result) {
297f03826f8SChristian Sigg if (result.use_empty() || result.hasOneUse())
298f03826f8SChristian Sigg return false;
299f03826f8SChristian Sigg auto valueType = result.getType().dyn_cast<async::ValueType>();
300f03826f8SChristian Sigg return valueType &&
301f03826f8SChristian Sigg valueType.getValueType().isa<gpu::AsyncTokenType>();
302f03826f8SChristian Sigg });
303f03826f8SChristian Sigg if (multiUseResults.empty())
304f03826f8SChristian Sigg return;
305f03826f8SChristian Sigg
306f03826f8SChristian Sigg // Indices within !async.execute results (i.e. without the async.token).
307f03826f8SChristian Sigg SmallVector<int, 4> indices;
308f03826f8SChristian Sigg transform(multiUseResults, std::back_inserter(indices),
309f03826f8SChristian Sigg [](OpResult result) {
310f03826f8SChristian Sigg return result.getResultNumber() - 1; // Index without token.
311f03826f8SChristian Sigg });
312f03826f8SChristian Sigg
313f03826f8SChristian Sigg for (auto index : indices) {
314f03826f8SChristian Sigg assert(!executeOp.results()[index].getUses().empty());
315f03826f8SChristian Sigg // Repeat async.yield token result, one for each use after the first one.
316f03826f8SChristian Sigg auto uses = llvm::drop_begin(executeOp.results()[index].getUses());
317f03826f8SChristian Sigg auto count = std::distance(uses.begin(), uses.end());
318f03826f8SChristian Sigg auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
319f03826f8SChristian Sigg SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
320f03826f8SChristian Sigg executeOp = addExecuteResults(executeOp, operands);
321f03826f8SChristian Sigg // Update 'uses' to refer to the new executeOp.
322f03826f8SChristian Sigg uses = llvm::drop_begin(executeOp.results()[index].getUses());
323f03826f8SChristian Sigg auto results = executeOp.results().take_back(count);
324f03826f8SChristian Sigg for (auto pair : llvm::zip(uses, results))
325f03826f8SChristian Sigg std::get<0>(pair).set(std::get<1>(pair));
326f03826f8SChristian Sigg }
327f03826f8SChristian Sigg }
328f03826f8SChristian Sigg };
329f03826f8SChristian Sigg
330db7129a0SChristian Sigg // Replaces synchronous GPU ops in the op's region with asynchronous ones and
331db7129a0SChristian Sigg // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
332db7129a0SChristian Sigg // execution semantics and that no GPU ops are asynchronous yet.
runOnOperation()33341574554SRiver Riddle void GpuAsyncRegionPass::runOnOperation() {
33441574554SRiver Riddle if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
335db7129a0SChristian Sigg return signalPassFailure();
336d9adde5aSChristian Sigg
337a79b26dbSChristian Sigg // Collect gpu.wait ops that we can move out of async.execute regions.
33841574554SRiver Riddle getOperation().getRegion().walk(DeferWaitCallback());
339f03826f8SChristian Sigg // Makes each !gpu.async.token returned from async.execute op have single use.
34041574554SRiver Riddle getOperation().getRegion().walk(SingleTokenUseCallback());
341db7129a0SChristian Sigg }
342db7129a0SChristian Sigg
createGpuAsyncRegionPass()34358ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
344db7129a0SChristian Sigg return std::make_unique<GpuAsyncRegionPass>();
345db7129a0SChristian Sigg }
346