1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU dialect pattern rewriters that make GPU op
10 // within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/GPU/Passes.h"
18 #include "mlir/Dialect/GPU/Utils.h"
19 #include "mlir/Dialect/StandardOps/IR/Ops.h"
20 #include "mlir/IR/BlockAndValueMapping.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Support/LLVM.h"
25 #include "mlir/Transforms/RegionUtils.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 using namespace mlir;
29 namespace {
// Function pass that rewrites synchronous GPU ops to their asynchronous form
// by threading `!gpu.async.token` values through them, and then tries to move
// trailing host synchronization (`gpu.wait`) out of `async.execute` regions.
class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  struct ThreadTokenCallback; // Rewrites GPU ops to execute asynchronously.
  struct DeferWaitCallback;   // Pushes `gpu.wait` out of `async.execute` ops.
  void runOnFunction() override;
};
35 } // namespace
36 
37 static bool isTerminator(Operation *op) { return !op->isKnownNonTerminator(); }
38 static bool hasSideEffects(Operation *op) {
39   return !MemoryEffectOpInterface::hasNoEffect(op);
40 }
41 
42 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
43 // execute asynchronously.
// Region walk callback which makes GPU ops implementing the AsyncOpInterface
// execute asynchronously.
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
  // create a current token (unless it already exists), and 'thread' that token
  // through the `op` so that it executes asynchronously.
  //
  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
  // host-synchronize execution. A `!gpu.async.token` will therefore only be
  // used inside of its block and GPU execution will always synchronize with
  // the host at block boundaries.
  WalkResult operator()(Operation *op) {
    // gpu.launch has a region; this pass only supports region-less GPU ops, so
    // launches must have been outlined to gpu.launch_func beforehand.
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");
    // This pass inserts all gpu.wait ops itself; a pre-existing one would
    // break the single-current-token invariant maintained below.
    if (isa<gpu::WaitOp>(op))
      return op->emitOpError("unexpected pre-existing gpu.wait");
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
    if (!currentToken)
      return success();
    // Insert host synchronization before terminator or op with side effects.
    // The synchronous wait produces no token (asyncToken() is null when the
    // result type is null), so this also resets `currentToken`.
    if (isTerminator(op) || hasSideEffects(op))
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

private:
  // Replaces asyncOp with a clone that returns a token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    auto *op = asyncOp.getOperation();
    if (asyncOp.getAsyncToken())
      // TODO: Support ops that are already async.
      return op->emitOpError("is already async");
    if (op->getNumRegions() > 0)
      return op->emitOpError("regions are not supported");

    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // If there is no current token, insert a `gpu.wait async` without
    // dependencies to create one.
    if (!currentToken)
      currentToken = createWaitOp(op->getLoc(), tokenType, {});
    // Make the op wait for (i.e. depend on) the current token.
    asyncOp.addAsyncDependency(currentToken);

    // Clone the op to return a token in addition to the other results. The
    // token is appended last so existing result positions are preserved.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
    copy(op->getResultTypes(), std::back_inserter(resultTypes));
    resultTypes.push_back(tokenType);
    auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
                                    op->getOperands(), op->getAttrDictionary(),
                                    op->getSuccessors());

    // Replace the op with the async clone. The clone's trailing token result
    // becomes the new current token; the remaining results replace the
    // original op's results one-to-one.
    auto results = newOp->getResults();
    currentToken = results.back();
    builder.insert(newOp);
    op->replaceAllUsesWith(results.drop_back());
    op->erase();

    return success();
  }

  // Creates a gpu.wait op at `loc` and returns its token result (null when
  // `resultType` is null, i.e. for a synchronous wait).
  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
  }

  OpBuilder builder;

  // The token that represents the current asynchronous dependency. Its valid
  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
  // In between, each gpu::AsyncOpInterface depends on the current token and
  // produces the new one.
  Value currentToken = {};
};
120 
121 // Callback for `async.execute` ops which tries to push the contained
122 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
// Callback for `async.execute` ops which tries to push the contained
// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
struct GpuAsyncRegionPass::DeferWaitCallback {
  // If the `executeOp`s token is used only in `async.execute` or `async.await`
  // ops, add the region's last `gpu.wait` op to the worklist if it is
  // synchronous and is the last op with side effects.
  void operator()(async::ExecuteOp executeOp) {
    if (!areAllUsersExecuteOrAwait(executeOp.token()))
      return;
    // async.execute's region is currently restricted to one block.
    // Scan backwards from the terminator: the wait is only deferrable if no
    // other side-effecting op follows it.
    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
        // Only a synchronous wait (no token result) can be deferred.
        if (!waitOp.asyncToken())
          worklist.push_back(waitOp);
        return;
      }
      if (hasSideEffects(&op))
        return;
    }
  }

  // The destructor performs the actual rewrite work.
  ~DeferWaitCallback() {
    // Index-based loop on purpose: addAsyncDependencyAfter() may append new
    // waits to `worklist` while we iterate, and push_back may invalidate
    // iterators.
    for (size_t i = 0; i < worklist.size(); ++i) {
      auto waitOp = worklist[i];
      auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
      auto numDependencies = waitOp.asyncDependencies().size();

      // Erase `gpu.wait` and return async dependencies from region instead.
      auto &yieldOp = executeOp.getBody()->getOperations().back();
      yieldOp.insertOperands(yieldOp.getNumOperands(),
                             waitOp.asyncDependencies());
      waitOp.erase();
      // Note: `executeOp` is updated in place to the cloned op.
      auto asyncTokens = addAsyncTokenResults(executeOp, numDependencies);

      // Add the async dependency to each user of the `async.execute` token.
      for (Operation *user : executeOp.token().getUsers())
        addAsyncDependencyAfter(asyncTokens, user);
    }
  }

private:
  // Append `count` `!async.value<!gpu.async.token>` results to `executeOp`.
  // `executeOp` is passed by reference and is replaced with the cloned op.
  // Returns the newly added result values.
  static ValueRange addAsyncTokenResults(async::ExecuteOp &executeOp,
                                         unsigned count) {
    auto numResults = executeOp.getNumResults() + count;

    // Construct new result type list with `count` additional types.
    SmallVector<Type, 2> resultTypes;
    resultTypes.reserve(numResults);
    transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
              [](Type type) {
                // Extract value type from !async.value.
                if (auto valueType = type.dyn_cast<async::ValueType>())
                  return valueType.getValueType();
                assert(type.isa<async::TokenType>() && "expected token type");
                return type;
              });
    OpBuilder builder(executeOp);
    auto tokenType = builder.getType<gpu::AsyncTokenType>();
    resultTypes.resize(numResults, tokenType);

    // Clone executeOp with the extra `!gpu.async.token` results. The builder
    // re-wraps the value types in !async.value; the implicit token result is
    // dropped from the list because the builder adds it itself.
    auto newOp = builder.create<async::ExecuteOp>(
        executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
        executeOp.dependencies(), executeOp.operands());
    BlockAndValueMapping mapper;
    newOp.getRegion().getBlocks().clear();
    executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

    // Replace executeOp with cloned one. The clone's extra trailing results
    // are excluded from the replacement since the original op lacks them.
    executeOp.getOperation()->replaceAllUsesWith(
        newOp.getResults().drop_back(count));
    executeOp.erase();
    executeOp = newOp;

    // Return the new result values.
    return executeOp.getResults().take_back(count);
  }

  // Returns whether all token users are either 'async.execute' or 'async.await'
  // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
  // 'async.execute' body to its users. Specifically, we do not allow
  // terminator users, because it could mean that the `async.execute` is inside
  // control flow code.
  static bool areAllUsersExecuteOrAwait(Value token) {
    return llvm::all_of(token.getUsers(), [](Operation *user) {
      return isa<async::ExecuteOp, async::AwaitOp>(user);
    });
  }

  // Add the `asyncToken` as dependency as needed after `op`.
  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
    OpBuilder builder(op->getContext());
    auto loc = op->getLoc();

    // `it` points to the first op after which the dependency must take effect;
    // `tokens` collects the !gpu.async.token values visible at that point.
    Block::iterator it;
    SmallVector<Value, 1> tokens;
    tokens.reserve(asyncTokens.size());
    TypeSwitch<Operation *>(op)
        .Case<async::AwaitOp>([&](auto awaitOp) {
          // Add async.await ops to wait for the !gpu.async.tokens.
          builder.setInsertionPointAfter(op);
          for (auto asyncToken : asyncTokens)
            tokens.push_back(
                builder.create<async::AwaitOp>(loc, asyncToken).result());
          // Set `it` after the inserted async.await ops.
          it = builder.getInsertionPoint();
        })
        .Case<async::ExecuteOp>([&](auto executeOp) {
          // Set `it` to the beginning of the region and add asyncTokens to the
          // async.execute operands. The tokens become block arguments of the
          // region entry block.
          it = executeOp.getBody()->begin();
          executeOp.operandsMutable().append(asyncTokens);
          SmallVector<Type, 1> tokenTypes(
              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
          copy(executeOp.getBody()->addArguments(tokenTypes),
               std::back_inserter(tokens));
        });

    // Advance `it` to terminator or op with side-effects. The block always has
    // a terminator, so the default-constructed end sentinel is never reached.
    it = std::find_if(it, Block::iterator(), [](Operation &op) {
      return isTerminator(&op) || hasSideEffects(&op);
    });

    // If `op` implements the AsyncOpInterface, add `token` to the list of async
    // dependencies.
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
      for (auto token : tokens)
        asyncOp.addAsyncDependency(token);
      return;
    }

    // Otherwise, insert a gpu.wait before 'it'.
    builder.setInsertionPoint(it->getBlock(), it);
    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);

    // If the new waitOp is at the end of an async.execute region, add it to the
    // worklist. 'operator()(executeOp)' would do the same, but this is faster.
    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
    if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
        !it->getNextNode())
      worklist.push_back(waitOp);
  }

  // Synchronous gpu.wait ops (inside async.execute regions) to defer; may
  // grow while the destructor processes it.
  SmallVector<gpu::WaitOp, 8> worklist;
};
268 
269 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
270 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
271 // execution semantics and that no GPU ops are asynchronous yet.
272 void GpuAsyncRegionPass::runOnFunction() {
273   if (getFunction()
274           .getRegion()
275           .walk(ThreadTokenCallback(getContext()))
276           .wasInterrupted())
277     return signalPassFailure();
278 
279   // Collect gpu.wait ops that we can move out of async.execute regions.
280   getFunction().getRegion().walk(DeferWaitCallback());
281 }
282 
283 std::unique_ptr<OperationPass<FuncOp>> mlir::createGpuAsyncRegionPass() {
284   return std::make_unique<GpuAsyncRegionPass>();
285 }
286