1d02f10d9SThomas Raoux //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2d02f10d9SThomas Raoux //
3d02f10d9SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d02f10d9SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
5d02f10d9SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d02f10d9SThomas Raoux //
7d02f10d9SThomas Raoux //===----------------------------------------------------------------------===//
8d02f10d9SThomas Raoux
9ed0288f7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
10d02f10d9SThomas Raoux #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11d02f10d9SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
128b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
13d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14087aba4fSThomas Raoux #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1576cf33daSThomas Raoux #include "mlir/IR/BlockAndValueMapping.h"
16ed0288f7SThomas Raoux #include "mlir/Transforms/SideEffectUtils.h"
17d7d6443dSThomas Raoux #include "llvm/ADT/SetVector.h"
1808d651d7SMehdi Amini #include <utility>
1908d651d7SMehdi Amini
20d02f10d9SThomas Raoux using namespace mlir;
21d02f10d9SThomas Raoux using namespace mlir::vector;
22d02f10d9SThomas Raoux
23d02f10d9SThomas Raoux static LogicalResult
rewriteWarpOpToScfFor(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,const WarpExecuteOnLane0LoweringOptions & options)24d02f10d9SThomas Raoux rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
25d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options) {
26d02f10d9SThomas Raoux assert(warpOp.getBodyRegion().hasOneBlock() &&
27d02f10d9SThomas Raoux "expected WarpOp with single block");
28d02f10d9SThomas Raoux Block *warpOpBody = &warpOp.getBodyRegion().front();
29d02f10d9SThomas Raoux Location loc = warpOp.getLoc();
30d02f10d9SThomas Raoux
31d02f10d9SThomas Raoux // Passed all checks. Start rewriting.
32d02f10d9SThomas Raoux OpBuilder::InsertionGuard g(rewriter);
33d02f10d9SThomas Raoux rewriter.setInsertionPoint(warpOp);
34d02f10d9SThomas Raoux
35d02f10d9SThomas Raoux // Create scf.if op.
36d02f10d9SThomas Raoux Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
37d02f10d9SThomas Raoux Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
38d02f10d9SThomas Raoux warpOp.getLaneid(), c0);
39d02f10d9SThomas Raoux auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
40d02f10d9SThomas Raoux /*withElseRegion=*/false);
41d02f10d9SThomas Raoux rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
42d02f10d9SThomas Raoux
43d02f10d9SThomas Raoux // Store vectors that are defined outside of warpOp into the scratch pad
44d02f10d9SThomas Raoux // buffer.
45d02f10d9SThomas Raoux SmallVector<Value> bbArgReplacements;
46d02f10d9SThomas Raoux for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
47d02f10d9SThomas Raoux Value val = it.value();
48d02f10d9SThomas Raoux Value bbArg = warpOpBody->getArgument(it.index());
49d02f10d9SThomas Raoux
50d02f10d9SThomas Raoux rewriter.setInsertionPoint(ifOp);
51f6c79c6aSNicolas Vasilache Value buffer =
52f6c79c6aSNicolas Vasilache options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
53d02f10d9SThomas Raoux
54d02f10d9SThomas Raoux // Store arg vector into buffer.
55d02f10d9SThomas Raoux rewriter.setInsertionPoint(ifOp);
56d02f10d9SThomas Raoux auto vectorType = val.getType().cast<VectorType>();
57d02f10d9SThomas Raoux int64_t storeSize = vectorType.getShape()[0];
58d02f10d9SThomas Raoux Value storeOffset = rewriter.create<arith::MulIOp>(
59d02f10d9SThomas Raoux loc, warpOp.getLaneid(),
60d02f10d9SThomas Raoux rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
61d02f10d9SThomas Raoux rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
62d02f10d9SThomas Raoux
63d02f10d9SThomas Raoux // Load bbArg vector from buffer.
64d02f10d9SThomas Raoux rewriter.setInsertionPointToStart(ifOp.thenBlock());
65d02f10d9SThomas Raoux auto bbArgType = bbArg.getType().cast<VectorType>();
66d02f10d9SThomas Raoux Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
67d02f10d9SThomas Raoux bbArgReplacements.push_back(loadOp);
68d02f10d9SThomas Raoux }
69d02f10d9SThomas Raoux
70d02f10d9SThomas Raoux // Insert sync after all the stores and before all the loads.
71d02f10d9SThomas Raoux if (!warpOp.getArgs().empty()) {
72d02f10d9SThomas Raoux rewriter.setInsertionPoint(ifOp);
73f6c79c6aSNicolas Vasilache options.warpSyncronizationFn(loc, rewriter, warpOp);
74d02f10d9SThomas Raoux }
75d02f10d9SThomas Raoux
76d02f10d9SThomas Raoux // Move body of warpOp to ifOp.
77d02f10d9SThomas Raoux rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
78d02f10d9SThomas Raoux
79d02f10d9SThomas Raoux // Rewrite terminator and compute replacements of WarpOp results.
80d02f10d9SThomas Raoux SmallVector<Value> replacements;
81d02f10d9SThomas Raoux auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
82d02f10d9SThomas Raoux Location yieldLoc = yieldOp.getLoc();
83d02f10d9SThomas Raoux for (const auto &it : llvm::enumerate(yieldOp.operands())) {
84d02f10d9SThomas Raoux Value val = it.value();
85d02f10d9SThomas Raoux Type resultType = warpOp->getResultTypes()[it.index()];
86d02f10d9SThomas Raoux rewriter.setInsertionPoint(ifOp);
87f6c79c6aSNicolas Vasilache Value buffer =
88f6c79c6aSNicolas Vasilache options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
89d02f10d9SThomas Raoux
90d02f10d9SThomas Raoux // Store yielded value into buffer.
91d02f10d9SThomas Raoux rewriter.setInsertionPoint(yieldOp);
92d02f10d9SThomas Raoux if (val.getType().isa<VectorType>())
93d02f10d9SThomas Raoux rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
94d02f10d9SThomas Raoux else
95d02f10d9SThomas Raoux rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
96d02f10d9SThomas Raoux
97d02f10d9SThomas Raoux // Load value from buffer (after warpOp).
98d02f10d9SThomas Raoux rewriter.setInsertionPointAfter(ifOp);
99d02f10d9SThomas Raoux if (resultType == val.getType()) {
100d02f10d9SThomas Raoux // Result type and yielded value type are the same. This is a broadcast.
101d02f10d9SThomas Raoux // E.g.:
102ed0288f7SThomas Raoux // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
103ed0288f7SThomas Raoux // vector.yield %cst : f32
104d02f10d9SThomas Raoux // }
105d02f10d9SThomas Raoux // Both types are f32. The constant %cst is broadcasted to all lanes.
106d02f10d9SThomas Raoux // This is described in more detail in the documentation of the op.
107d02f10d9SThomas Raoux Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
108d02f10d9SThomas Raoux replacements.push_back(loadOp);
109d02f10d9SThomas Raoux } else {
110d02f10d9SThomas Raoux auto loadedVectorType = resultType.cast<VectorType>();
111d02f10d9SThomas Raoux int64_t loadSize = loadedVectorType.getShape()[0];
112d02f10d9SThomas Raoux
113d02f10d9SThomas Raoux // loadOffset = laneid * loadSize
114d02f10d9SThomas Raoux Value loadOffset = rewriter.create<arith::MulIOp>(
115d02f10d9SThomas Raoux loc, warpOp.getLaneid(),
116d02f10d9SThomas Raoux rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
117d02f10d9SThomas Raoux Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
118d02f10d9SThomas Raoux buffer, loadOffset);
119d02f10d9SThomas Raoux replacements.push_back(loadOp);
120d02f10d9SThomas Raoux }
121d02f10d9SThomas Raoux }
122d02f10d9SThomas Raoux
123d02f10d9SThomas Raoux // Insert sync after all the stores and before all the loads.
124d02f10d9SThomas Raoux if (!yieldOp.operands().empty()) {
125d02f10d9SThomas Raoux rewriter.setInsertionPointAfter(ifOp);
126f6c79c6aSNicolas Vasilache options.warpSyncronizationFn(loc, rewriter, warpOp);
127d02f10d9SThomas Raoux }
128d02f10d9SThomas Raoux
129d02f10d9SThomas Raoux // Delete terminator and add empty scf.yield.
130d02f10d9SThomas Raoux rewriter.eraseOp(yieldOp);
131d02f10d9SThomas Raoux rewriter.setInsertionPointToEnd(ifOp.thenBlock());
132d02f10d9SThomas Raoux rewriter.create<scf::YieldOp>(yieldLoc);
133d02f10d9SThomas Raoux
134d02f10d9SThomas Raoux // Compute replacements for WarpOp results.
135d02f10d9SThomas Raoux rewriter.replaceOp(warpOp, replacements);
136d02f10d9SThomas Raoux
137d02f10d9SThomas Raoux return success();
138d02f10d9SThomas Raoux }
139d02f10d9SThomas Raoux
140ed0288f7SThomas Raoux /// Helper to create a new WarpExecuteOnLane0Op with different signature.
moveRegionToNewWarpOpAndReplaceReturns(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,ValueRange newYieldedValues,TypeRange newReturnTypes)141ed0288f7SThomas Raoux static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
142ed0288f7SThomas Raoux RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
143ed0288f7SThomas Raoux ValueRange newYieldedValues, TypeRange newReturnTypes) {
144ed0288f7SThomas Raoux // Create a new op before the existing one, with the extra operands.
145ed0288f7SThomas Raoux OpBuilder::InsertionGuard g(rewriter);
146ed0288f7SThomas Raoux rewriter.setInsertionPoint(warpOp);
147ed0288f7SThomas Raoux auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
148ed0288f7SThomas Raoux warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
149ed0288f7SThomas Raoux warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
150ed0288f7SThomas Raoux
151ed0288f7SThomas Raoux Region &opBody = warpOp.getBodyRegion();
152ed0288f7SThomas Raoux Region &newOpBody = newWarpOp.getBodyRegion();
153f6c79c6aSNicolas Vasilache Block &newOpFirstBlock = newOpBody.front();
154ed0288f7SThomas Raoux rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
155f6c79c6aSNicolas Vasilache rewriter.eraseBlock(&newOpFirstBlock);
156f6c79c6aSNicolas Vasilache assert(newWarpOp.getWarpRegion().hasOneBlock() &&
157f6c79c6aSNicolas Vasilache "expected WarpOp with single block");
158f6c79c6aSNicolas Vasilache
159ed0288f7SThomas Raoux auto yield =
160ed0288f7SThomas Raoux cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
161ed0288f7SThomas Raoux
162ed0288f7SThomas Raoux rewriter.updateRootInPlace(
163ed0288f7SThomas Raoux yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
164ed0288f7SThomas Raoux return newWarpOp;
165ed0288f7SThomas Raoux }
166ed0288f7SThomas Raoux
167ed0288f7SThomas Raoux /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
168d7d6443dSThomas Raoux /// `indices` return the index of each new output.
moveRegionToNewWarpOpAndAppendReturns(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,ValueRange newYieldedValues,TypeRange newReturnTypes,llvm::SmallVector<size_t> & indices)169ed0288f7SThomas Raoux static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
170ed0288f7SThomas Raoux RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
171d7d6443dSThomas Raoux ValueRange newYieldedValues, TypeRange newReturnTypes,
172d7d6443dSThomas Raoux llvm::SmallVector<size_t> &indices) {
173ed0288f7SThomas Raoux SmallVector<Type> types(warpOp.getResultTypes().begin(),
174ed0288f7SThomas Raoux warpOp.getResultTypes().end());
175ed0288f7SThomas Raoux auto yield = cast<vector::YieldOp>(
176ed0288f7SThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
177d7d6443dSThomas Raoux llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
178ed0288f7SThomas Raoux yield.getOperands().end());
179d7d6443dSThomas Raoux for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
180d7d6443dSThomas Raoux if (yieldValues.insert(std::get<0>(newRet))) {
181d7d6443dSThomas Raoux types.push_back(std::get<1>(newRet));
182d7d6443dSThomas Raoux indices.push_back(yieldValues.size() - 1);
183d7d6443dSThomas Raoux } else {
184d7d6443dSThomas Raoux // If the value already exit the region don't create a new output.
185d7d6443dSThomas Raoux for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
186d7d6443dSThomas Raoux if (yieldOperand.value() == std::get<0>(newRet)) {
187d7d6443dSThomas Raoux indices.push_back(yieldOperand.index());
188d7d6443dSThomas Raoux break;
189d7d6443dSThomas Raoux }
190d7d6443dSThomas Raoux }
191d7d6443dSThomas Raoux }
192d7d6443dSThomas Raoux }
193d7d6443dSThomas Raoux yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
194ed0288f7SThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
195d7d6443dSThomas Raoux rewriter, warpOp, yieldValues.getArrayRef(), types);
196ed0288f7SThomas Raoux rewriter.replaceOp(warpOp,
197ed0288f7SThomas Raoux newWarpOp.getResults().take_front(warpOp.getNumResults()));
198ed0288f7SThomas Raoux return newWarpOp;
199ed0288f7SThomas Raoux }
200ed0288f7SThomas Raoux
201ed0288f7SThomas Raoux /// Helper to know if an op can be hoisted out of the region.
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)202ed0288f7SThomas Raoux static bool canBeHoisted(Operation *op,
203ed0288f7SThomas Raoux function_ref<bool(Value)> definedOutside) {
204ed0288f7SThomas Raoux return llvm::all_of(op->getOperands(), definedOutside) &&
205ed0288f7SThomas Raoux isSideEffectFree(op) && op->getNumRegions() == 0;
206ed0288f7SThomas Raoux }
207ed0288f7SThomas Raoux
20876cf33daSThomas Raoux /// Return a value yielded by `warpOp` which statifies the filter lamdba
20976cf33daSThomas Raoux /// condition and is not dead.
getWarpResult(WarpExecuteOnLane0Op warpOp,std::function<bool (Operation *)> fn)21076cf33daSThomas Raoux static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
21176cf33daSThomas Raoux std::function<bool(Operation *)> fn) {
21276cf33daSThomas Raoux auto yield = cast<vector::YieldOp>(
21376cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
21476cf33daSThomas Raoux for (OpOperand &yieldOperand : yield->getOpOperands()) {
21576cf33daSThomas Raoux Value yieldValues = yieldOperand.get();
21676cf33daSThomas Raoux Operation *definedOp = yieldValues.getDefiningOp();
21776cf33daSThomas Raoux if (definedOp && fn(definedOp)) {
21876cf33daSThomas Raoux if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
21976cf33daSThomas Raoux return &yieldOperand;
22076cf33daSThomas Raoux }
22176cf33daSThomas Raoux }
22276cf33daSThomas Raoux return {};
22376cf33daSThomas Raoux }
22476cf33daSThomas Raoux
22576cf33daSThomas Raoux // Clones `op` into a new operation that takes `operands` and returns
22676cf33daSThomas Raoux // `resultTypes`.
cloneOpWithOperandsAndTypes(RewriterBase & rewriter,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)22776cf33daSThomas Raoux static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
22876cf33daSThomas Raoux Location loc, Operation *op,
22976cf33daSThomas Raoux ArrayRef<Value> operands,
23076cf33daSThomas Raoux ArrayRef<Type> resultTypes) {
23176cf33daSThomas Raoux OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
23276cf33daSThomas Raoux op->getAttrs());
23376cf33daSThomas Raoux return rewriter.create(res);
23476cf33daSThomas Raoux }
23576cf33daSThomas Raoux
23676cf33daSThomas Raoux /// Currently the distribution map is implicit based on the vector shape. In the
23776cf33daSThomas Raoux /// future it will be part of the op.
23876cf33daSThomas Raoux /// Example:
23976cf33daSThomas Raoux /// ```
24076cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
24176cf33daSThomas Raoux /// ...
24276cf33daSThomas Raoux /// vector.yield %3 : vector<32x16x64xf32>
24376cf33daSThomas Raoux /// }
24476cf33daSThomas Raoux /// ```
24576cf33daSThomas Raoux /// Would have an implicit map of:
24676cf33daSThomas Raoux /// `(d0, d1, d2) -> (d0, d2)`
calculateImplicitMap(Value yield,Value ret)24776cf33daSThomas Raoux static AffineMap calculateImplicitMap(Value yield, Value ret) {
24876cf33daSThomas Raoux auto srcType = yield.getType().cast<VectorType>();
24976cf33daSThomas Raoux auto dstType = ret.getType().cast<VectorType>();
25076cf33daSThomas Raoux SmallVector<AffineExpr> perm;
25176cf33daSThomas Raoux // Check which dimensions of the yield value are different than the dimensions
25276cf33daSThomas Raoux // of the result to know the distributed dimensions. Then associate each
25376cf33daSThomas Raoux // distributed dimension to an ID in order.
25476cf33daSThomas Raoux for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
25576cf33daSThomas Raoux if (srcType.getDimSize(i) != dstType.getDimSize(i))
25676cf33daSThomas Raoux perm.push_back(getAffineDimExpr(i, yield.getContext()));
25776cf33daSThomas Raoux }
25876cf33daSThomas Raoux auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
25976cf33daSThomas Raoux return map;
26076cf33daSThomas Raoux }
26176cf33daSThomas Raoux
262d02f10d9SThomas Raoux namespace {
263d02f10d9SThomas Raoux
264d02f10d9SThomas Raoux struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpToScfForPattern__anonbfa0aa500211::WarpOpToScfForPattern265d02f10d9SThomas Raoux WarpOpToScfForPattern(MLIRContext *context,
266d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options,
267d02f10d9SThomas Raoux PatternBenefit benefit = 1)
268d02f10d9SThomas Raoux : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
269d02f10d9SThomas Raoux options(options) {}
270d02f10d9SThomas Raoux
matchAndRewrite__anonbfa0aa500211::WarpOpToScfForPattern271d02f10d9SThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
272d02f10d9SThomas Raoux PatternRewriter &rewriter) const override {
273d02f10d9SThomas Raoux return rewriteWarpOpToScfFor(rewriter, warpOp, options);
274d02f10d9SThomas Raoux }
275d02f10d9SThomas Raoux
276d02f10d9SThomas Raoux private:
277d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options;
278d02f10d9SThomas Raoux };
279d02f10d9SThomas Raoux
2806a57d8fbSNicolas Vasilache /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
2816a57d8fbSNicolas Vasilache /// op with the proper return type.
2826a57d8fbSNicolas Vasilache /// The new write op is updated to write the result of the new warp execute op.
2836a57d8fbSNicolas Vasilache /// The old `writeOp` is deleted.
cloneWriteOp(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,vector::TransferWriteOp writeOp,VectorType targetType)2846a57d8fbSNicolas Vasilache static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
2856a57d8fbSNicolas Vasilache WarpExecuteOnLane0Op warpOp,
2866a57d8fbSNicolas Vasilache vector::TransferWriteOp writeOp,
2876a57d8fbSNicolas Vasilache VectorType targetType) {
2886a57d8fbSNicolas Vasilache assert(writeOp->getParentOp() == warpOp &&
2896a57d8fbSNicolas Vasilache "write must be nested immediately under warp");
2906a57d8fbSNicolas Vasilache OpBuilder::InsertionGuard g(rewriter);
291d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices;
2926a57d8fbSNicolas Vasilache WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2936a57d8fbSNicolas Vasilache rewriter, warpOp, ValueRange{{writeOp.getVector()}},
294d7d6443dSThomas Raoux TypeRange{targetType}, newRetIndices);
2956a57d8fbSNicolas Vasilache rewriter.setInsertionPointAfter(newWarpOp);
2966a57d8fbSNicolas Vasilache auto newWriteOp =
2976a57d8fbSNicolas Vasilache cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
2986a57d8fbSNicolas Vasilache rewriter.eraseOp(writeOp);
299d7d6443dSThomas Raoux newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
3006a57d8fbSNicolas Vasilache return newWriteOp;
3016a57d8fbSNicolas Vasilache }
3026a57d8fbSNicolas Vasilache
303ed0288f7SThomas Raoux /// Distribute transfer_write ops based on the affine map returned by
304ed0288f7SThomas Raoux /// `distributionMapFn`.
305ed0288f7SThomas Raoux /// Example:
306ed0288f7SThomas Raoux /// ```
307ed0288f7SThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%id){
308ed0288f7SThomas Raoux /// ...
309ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
310ed0288f7SThomas Raoux /// vector.yield
311ed0288f7SThomas Raoux /// }
312ed0288f7SThomas Raoux /// ```
313ed0288f7SThomas Raoux /// To
314ed0288f7SThomas Raoux /// ```
315ed0288f7SThomas Raoux /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
316ed0288f7SThomas Raoux /// ...
317ed0288f7SThomas Raoux /// vector.yield %v : vector<32xf32>
318ed0288f7SThomas Raoux /// }
319ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
320ed0288f7SThomas Raoux struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
WarpOpTransferWrite__anonbfa0aa500211::WarpOpTransferWrite321ed0288f7SThomas Raoux WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
322ed0288f7SThomas Raoux PatternBenefit b = 1)
323ed0288f7SThomas Raoux : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
32408d651d7SMehdi Amini distributionMapFn(std::move(fn)) {}
325ed0288f7SThomas Raoux
326ed0288f7SThomas Raoux /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
327ed0288f7SThomas Raoux /// are multiples of the distribution ratio are supported at the moment.
tryDistributeOp__anonbfa0aa500211::WarpOpTransferWrite328ed0288f7SThomas Raoux LogicalResult tryDistributeOp(RewriterBase &rewriter,
329ed0288f7SThomas Raoux vector::TransferWriteOp writeOp,
330ed0288f7SThomas Raoux WarpExecuteOnLane0Op warpOp) const {
3316a57d8fbSNicolas Vasilache VectorType writtenVectorType = writeOp.getVectorType();
3326a57d8fbSNicolas Vasilache
3336a57d8fbSNicolas Vasilache // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
3346a57d8fbSNicolas Vasilache // to separate it from the rest.
3356a57d8fbSNicolas Vasilache if (writtenVectorType.getRank() == 0)
3366a57d8fbSNicolas Vasilache return failure();
3376a57d8fbSNicolas Vasilache
3386a57d8fbSNicolas Vasilache // 2. Compute the distribution map.
339ed0288f7SThomas Raoux AffineMap map = distributionMapFn(writeOp);
3406a57d8fbSNicolas Vasilache if (map.getNumResults() != 1)
3416a57d8fbSNicolas Vasilache return writeOp->emitError("multi-dim distribution not implemented yet");
3426a57d8fbSNicolas Vasilache
3436a57d8fbSNicolas Vasilache // 3. Compute the targetType using the distribution map.
3446a57d8fbSNicolas Vasilache SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
3456a57d8fbSNicolas Vasilache writtenVectorType.getShape().end());
346ed0288f7SThomas Raoux for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
347ed0288f7SThomas Raoux unsigned position = map.getDimPosition(i);
348ed0288f7SThomas Raoux if (targetShape[position] % warpOp.getWarpSize() != 0)
349ed0288f7SThomas Raoux return failure();
350ed0288f7SThomas Raoux targetShape[position] = targetShape[position] / warpOp.getWarpSize();
351ed0288f7SThomas Raoux }
352ed0288f7SThomas Raoux VectorType targetType =
3536a57d8fbSNicolas Vasilache VectorType::get(targetShape, writtenVectorType.getElementType());
354ed0288f7SThomas Raoux
3556a57d8fbSNicolas Vasilache // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
3566a57d8fbSNicolas Vasilache // the rest.
3576a57d8fbSNicolas Vasilache vector::TransferWriteOp newWriteOp =
3586a57d8fbSNicolas Vasilache cloneWriteOp(rewriter, warpOp, writeOp, targetType);
359ed0288f7SThomas Raoux
3606a57d8fbSNicolas Vasilache // 5. Reindex the write using the distribution map.
3616a57d8fbSNicolas Vasilache auto newWarpOp =
3626a57d8fbSNicolas Vasilache newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
363ed0288f7SThomas Raoux rewriter.setInsertionPoint(newWriteOp);
364ed0288f7SThomas Raoux AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
365ed0288f7SThomas Raoux Location loc = newWriteOp.getLoc();
366ed0288f7SThomas Raoux SmallVector<Value> indices(newWriteOp.getIndices().begin(),
367ed0288f7SThomas Raoux newWriteOp.getIndices().end());
368ed0288f7SThomas Raoux for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
369ed0288f7SThomas Raoux AffineExpr d0, d1;
370ed0288f7SThomas Raoux bindDims(newWarpOp.getContext(), d0, d1);
371ed0288f7SThomas Raoux auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
372ed0288f7SThomas Raoux if (!indexExpr)
373ed0288f7SThomas Raoux continue;
374ed0288f7SThomas Raoux unsigned indexPos = indexExpr.getPosition();
375ed0288f7SThomas Raoux unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
3766a57d8fbSNicolas Vasilache auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
377ed0288f7SThomas Raoux indices[indexPos] =
378ed0288f7SThomas Raoux makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
379ed0288f7SThomas Raoux {indices[indexPos], newWarpOp.getLaneid()});
380ed0288f7SThomas Raoux }
381ed0288f7SThomas Raoux newWriteOp.getIndicesMutable().assign(indices);
382ed0288f7SThomas Raoux
383ed0288f7SThomas Raoux return success();
384ed0288f7SThomas Raoux }
385ed0288f7SThomas Raoux
386ed0288f7SThomas Raoux /// Extract TransferWriteOps of vector<1x> into a separate warp op.
tryExtractOp__anonbfa0aa500211::WarpOpTransferWrite387ed0288f7SThomas Raoux LogicalResult tryExtractOp(RewriterBase &rewriter,
388ed0288f7SThomas Raoux vector::TransferWriteOp writeOp,
389ed0288f7SThomas Raoux WarpExecuteOnLane0Op warpOp) const {
390ed0288f7SThomas Raoux Location loc = writeOp.getLoc();
391ed0288f7SThomas Raoux VectorType vecType = writeOp.getVectorType();
392ed0288f7SThomas Raoux
3937eba5cdfSThomas Raoux // Only sink out vector of 1 element for now to not serialize large vector
3947eba5cdfSThomas Raoux // store. This can later be controlled by user.
3957eba5cdfSThomas Raoux if (vecType.getNumElements() != 1)
396ed0288f7SThomas Raoux return failure();
397ed0288f7SThomas Raoux
398ed0288f7SThomas Raoux // Do not process warp ops that contain only TransferWriteOps.
399ed0288f7SThomas Raoux if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
400ed0288f7SThomas Raoux return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
401ed0288f7SThomas Raoux }))
402ed0288f7SThomas Raoux return failure();
403ed0288f7SThomas Raoux
404ed0288f7SThomas Raoux SmallVector<Value> yieldValues = {writeOp.getVector()};
405ed0288f7SThomas Raoux SmallVector<Type> retTypes = {vecType};
406d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices;
407ed0288f7SThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
408d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices);
409ed0288f7SThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
410ed0288f7SThomas Raoux
411ed0288f7SThomas Raoux // Create a second warp op that contains only writeOp.
412ed0288f7SThomas Raoux auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
413ed0288f7SThomas Raoux loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
414ed0288f7SThomas Raoux Block &body = secondWarpOp.getBodyRegion().front();
415ed0288f7SThomas Raoux rewriter.setInsertionPointToStart(&body);
416ed0288f7SThomas Raoux auto newWriteOp =
417ed0288f7SThomas Raoux cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
418d7d6443dSThomas Raoux newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
419ed0288f7SThomas Raoux rewriter.eraseOp(writeOp);
420ed0288f7SThomas Raoux rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
421ed0288f7SThomas Raoux return success();
422ed0288f7SThomas Raoux }
423ed0288f7SThomas Raoux
matchAndRewrite__anonbfa0aa500211::WarpOpTransferWrite424ed0288f7SThomas Raoux LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
425ed0288f7SThomas Raoux PatternRewriter &rewriter) const override {
426ed0288f7SThomas Raoux // Ops with mask not supported yet.
427ed0288f7SThomas Raoux if (writeOp.getMask())
428ed0288f7SThomas Raoux return failure();
429ed0288f7SThomas Raoux
430ed0288f7SThomas Raoux auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
431ed0288f7SThomas Raoux if (!warpOp)
432ed0288f7SThomas Raoux return failure();
433ed0288f7SThomas Raoux
434ed0288f7SThomas Raoux // There must be no op with a side effect after writeOp.
435ed0288f7SThomas Raoux Operation *nextOp = writeOp.getOperation();
436ed0288f7SThomas Raoux while ((nextOp = nextOp->getNextNode()))
437ed0288f7SThomas Raoux if (!isSideEffectFree(nextOp))
438ed0288f7SThomas Raoux return failure();
439ed0288f7SThomas Raoux
440ed0288f7SThomas Raoux if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
441ed0288f7SThomas Raoux return writeOp.getVector() == value ||
442ed0288f7SThomas Raoux warpOp.isDefinedOutsideOfRegion(value);
443ed0288f7SThomas Raoux }))
444ed0288f7SThomas Raoux return failure();
445ed0288f7SThomas Raoux
446ed0288f7SThomas Raoux if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
447ed0288f7SThomas Raoux return success();
448ed0288f7SThomas Raoux
449ed0288f7SThomas Raoux if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
450ed0288f7SThomas Raoux return success();
451ed0288f7SThomas Raoux
452ed0288f7SThomas Raoux return failure();
453ed0288f7SThomas Raoux }
454ed0288f7SThomas Raoux
455ed0288f7SThomas Raoux private:
456ed0288f7SThomas Raoux DistributionMapFn distributionMapFn;
457ed0288f7SThomas Raoux };
458ed0288f7SThomas Raoux
45976cf33daSThomas Raoux /// Sink out elementwise op feeding into a warp op yield.
46076cf33daSThomas Raoux /// ```
46176cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
46276cf33daSThomas Raoux /// ...
46376cf33daSThomas Raoux /// %3 = arith.addf %1, %2 : vector<32xf32>
46476cf33daSThomas Raoux /// vector.yield %3 : vector<32xf32>
46576cf33daSThomas Raoux /// }
46676cf33daSThomas Raoux /// ```
46776cf33daSThomas Raoux /// To
46876cf33daSThomas Raoux /// ```
46976cf33daSThomas Raoux /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
47076cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
47176cf33daSThomas Raoux /// ...
47276cf33daSThomas Raoux /// %4 = arith.addf %2, %3 : vector<32xf32>
47376cf33daSThomas Raoux /// vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
47476cf33daSThomas Raoux /// vector<32xf32>
47576cf33daSThomas Raoux /// }
47676cf33daSThomas Raoux /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
47776cf33daSThomas Raoux struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
47876cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpElementwise47976cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
48076cf33daSThomas Raoux PatternRewriter &rewriter) const override {
48176cf33daSThomas Raoux OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
48276cf33daSThomas Raoux return OpTrait::hasElementwiseMappableTraits(op);
48376cf33daSThomas Raoux });
48476cf33daSThomas Raoux if (!yieldOperand)
48576cf33daSThomas Raoux return failure();
48676cf33daSThomas Raoux Operation *elementWise = yieldOperand->get().getDefiningOp();
48776cf33daSThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber();
48876cf33daSThomas Raoux Value distributedVal = warpOp.getResult(operandIndex);
48976cf33daSThomas Raoux SmallVector<Value> yieldValues;
49076cf33daSThomas Raoux SmallVector<Type> retTypes;
49176cf33daSThomas Raoux Location loc = warpOp.getLoc();
49276cf33daSThomas Raoux for (OpOperand &operand : elementWise->getOpOperands()) {
49376cf33daSThomas Raoux Type targetType;
49476cf33daSThomas Raoux if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
49576cf33daSThomas Raoux // If the result type is a vector, the operands must also be vectors.
49676cf33daSThomas Raoux auto operandType = operand.get().getType().cast<VectorType>();
49776cf33daSThomas Raoux targetType =
49876cf33daSThomas Raoux VectorType::get(vecType.getShape(), operandType.getElementType());
49976cf33daSThomas Raoux } else {
50076cf33daSThomas Raoux auto operandType = operand.get().getType();
50176cf33daSThomas Raoux assert(!operandType.isa<VectorType>() &&
50276cf33daSThomas Raoux "unexpected yield of vector from op with scalar result type");
50376cf33daSThomas Raoux targetType = operandType;
50476cf33daSThomas Raoux }
50576cf33daSThomas Raoux retTypes.push_back(targetType);
50676cf33daSThomas Raoux yieldValues.push_back(operand.get());
50776cf33daSThomas Raoux }
508d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices;
50976cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
510d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices);
51176cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
51276cf33daSThomas Raoux SmallVector<Value> newOperands(elementWise->getOperands().begin(),
51376cf33daSThomas Raoux elementWise->getOperands().end());
51476cf33daSThomas Raoux for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
515d7d6443dSThomas Raoux newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
51676cf33daSThomas Raoux }
51776cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter);
51876cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
51976cf33daSThomas Raoux Operation *newOp = cloneOpWithOperandsAndTypes(
52076cf33daSThomas Raoux rewriter, loc, elementWise, newOperands,
52176cf33daSThomas Raoux {newWarpOp.getResult(operandIndex).getType()});
52276cf33daSThomas Raoux newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
52376cf33daSThomas Raoux return success();
52476cf33daSThomas Raoux }
52576cf33daSThomas Raoux };
52676cf33daSThomas Raoux
5270af26805SThomas Raoux /// Sink out splat constant op feeding into a warp op yield.
5280af26805SThomas Raoux /// ```
5290af26805SThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
5300af26805SThomas Raoux /// ...
5310af26805SThomas Raoux /// %cst = arith.constant dense<2.0> : vector<32xf32>
5320af26805SThomas Raoux /// vector.yield %cst : vector<32xf32>
5330af26805SThomas Raoux /// }
5340af26805SThomas Raoux /// ```
5350af26805SThomas Raoux /// To
5360af26805SThomas Raoux /// ```
5370af26805SThomas Raoux /// vector.warp_execute_on_lane_0(%arg0 {
5380af26805SThomas Raoux /// ...
5390af26805SThomas Raoux /// }
5400af26805SThomas Raoux /// %0 = arith.constant dense<2.0> : vector<1xf32>
5410af26805SThomas Raoux struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
5420af26805SThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpConstant5430af26805SThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
5440af26805SThomas Raoux PatternRewriter &rewriter) const override {
5450af26805SThomas Raoux OpOperand *yieldOperand = getWarpResult(
5460af26805SThomas Raoux warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
5470af26805SThomas Raoux if (!yieldOperand)
5480af26805SThomas Raoux return failure();
5490af26805SThomas Raoux auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
5500af26805SThomas Raoux auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
5510af26805SThomas Raoux if (!dense)
5520af26805SThomas Raoux return failure();
5530af26805SThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber();
5540af26805SThomas Raoux Attribute scalarAttr = dense.getSplatValue<Attribute>();
5550af26805SThomas Raoux Attribute newAttr = DenseElementsAttr::get(
5560af26805SThomas Raoux warpOp.getResult(operandIndex).getType(), scalarAttr);
5570af26805SThomas Raoux Location loc = warpOp.getLoc();
5580af26805SThomas Raoux rewriter.setInsertionPointAfter(warpOp);
5590af26805SThomas Raoux Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
5600af26805SThomas Raoux warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
5610af26805SThomas Raoux return success();
5620af26805SThomas Raoux }
5630af26805SThomas Raoux };
5640af26805SThomas Raoux
56576cf33daSThomas Raoux /// Sink out transfer_read op feeding into a warp op yield.
56676cf33daSThomas Raoux /// ```
56776cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
56876cf33daSThomas Raoux /// ...
56976cf33daSThomas Raoux // %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
57076cf33daSThomas Raoux // vector<32xf32>
57176cf33daSThomas Raoux /// vector.yield %2 : vector<32xf32>
57276cf33daSThomas Raoux /// }
57376cf33daSThomas Raoux /// ```
57476cf33daSThomas Raoux /// To
57576cf33daSThomas Raoux /// ```
57676cf33daSThomas Raoux /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
57776cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
57876cf33daSThomas Raoux /// ...
57976cf33daSThomas Raoux /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
58076cf33daSThomas Raoux /// vector<32xf32> vector.yield %2 : vector<32xf32>
58176cf33daSThomas Raoux /// }
58276cf33daSThomas Raoux /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
58376cf33daSThomas Raoux struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
58476cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpTransferRead58576cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
58676cf33daSThomas Raoux PatternRewriter &rewriter) const override {
58776cf33daSThomas Raoux OpOperand *operand = getWarpResult(
58876cf33daSThomas Raoux warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
58976cf33daSThomas Raoux if (!operand)
59076cf33daSThomas Raoux return failure();
59176cf33daSThomas Raoux auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
59276cf33daSThomas Raoux unsigned operandIndex = operand->getOperandNumber();
59376cf33daSThomas Raoux Value distributedVal = warpOp.getResult(operandIndex);
59476cf33daSThomas Raoux
59576cf33daSThomas Raoux SmallVector<Value, 4> indices(read.getIndices().begin(),
59676cf33daSThomas Raoux read.getIndices().end());
59776cf33daSThomas Raoux AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
59876cf33daSThomas Raoux AffineMap indexMap = map.compose(read.getPermutationMap());
59976cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter);
60076cf33daSThomas Raoux rewriter.setInsertionPointAfter(warpOp);
60176cf33daSThomas Raoux for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
60276cf33daSThomas Raoux AffineExpr d0, d1;
60376cf33daSThomas Raoux bindDims(read.getContext(), d0, d1);
60476cf33daSThomas Raoux auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
60576cf33daSThomas Raoux if (!indexExpr)
60676cf33daSThomas Raoux continue;
60776cf33daSThomas Raoux unsigned indexPos = indexExpr.getPosition();
60876cf33daSThomas Raoux unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
60976cf33daSThomas Raoux int64_t scale =
61076cf33daSThomas Raoux distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
61176cf33daSThomas Raoux indices[indexPos] =
61276cf33daSThomas Raoux makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
61376cf33daSThomas Raoux {indices[indexPos], warpOp.getLaneid()});
61476cf33daSThomas Raoux }
61576cf33daSThomas Raoux Value newRead = rewriter.create<vector::TransferReadOp>(
61676cf33daSThomas Raoux read.getLoc(), distributedVal.getType(), read.getSource(), indices,
61776cf33daSThomas Raoux read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
61876cf33daSThomas Raoux read.getInBoundsAttr());
61976cf33daSThomas Raoux distributedVal.replaceAllUsesWith(newRead);
62076cf33daSThomas Raoux return success();
62176cf33daSThomas Raoux }
62276cf33daSThomas Raoux };
62376cf33daSThomas Raoux
62476cf33daSThomas Raoux /// Remove any result that has no use along with the matching yieldOp operand.
62576cf33daSThomas Raoux // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
62676cf33daSThomas Raoux struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
62776cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpDeadResult62876cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
62976cf33daSThomas Raoux PatternRewriter &rewriter) const override {
63076cf33daSThomas Raoux SmallVector<Type> resultTypes;
63176cf33daSThomas Raoux SmallVector<Value> yieldValues;
63276cf33daSThomas Raoux auto yield = cast<vector::YieldOp>(
63376cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
63476cf33daSThomas Raoux for (OpResult result : warpOp.getResults()) {
63576cf33daSThomas Raoux if (result.use_empty())
63676cf33daSThomas Raoux continue;
63776cf33daSThomas Raoux resultTypes.push_back(result.getType());
63876cf33daSThomas Raoux yieldValues.push_back(yield.getOperand(result.getResultNumber()));
63976cf33daSThomas Raoux }
64076cf33daSThomas Raoux if (yield.getNumOperands() == yieldValues.size())
64176cf33daSThomas Raoux return failure();
64276cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
64376cf33daSThomas Raoux rewriter, warpOp, yieldValues, resultTypes);
64476cf33daSThomas Raoux unsigned resultIndex = 0;
64576cf33daSThomas Raoux for (OpResult result : warpOp.getResults()) {
64676cf33daSThomas Raoux if (result.use_empty())
64776cf33daSThomas Raoux continue;
64876cf33daSThomas Raoux result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
64976cf33daSThomas Raoux }
65076cf33daSThomas Raoux rewriter.eraseOp(warpOp);
65176cf33daSThomas Raoux return success();
65276cf33daSThomas Raoux }
65376cf33daSThomas Raoux };
65476cf33daSThomas Raoux
65576cf33daSThomas Raoux // If an operand is directly yielded out of the region we can forward it
65676cf33daSThomas Raoux // directly and it doesn't need to go through the region.
65776cf33daSThomas Raoux struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
65876cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpForwardOperand65976cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
66076cf33daSThomas Raoux PatternRewriter &rewriter) const override {
66176cf33daSThomas Raoux SmallVector<Type> resultTypes;
66276cf33daSThomas Raoux SmallVector<Value> yieldValues;
66376cf33daSThomas Raoux auto yield = cast<vector::YieldOp>(
66476cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
66576cf33daSThomas Raoux Value valForwarded;
66676cf33daSThomas Raoux unsigned resultIndex;
66776cf33daSThomas Raoux for (OpOperand &operand : yield->getOpOperands()) {
66876cf33daSThomas Raoux Value result = warpOp.getResult(operand.getOperandNumber());
66976cf33daSThomas Raoux if (result.use_empty())
67076cf33daSThomas Raoux continue;
67176cf33daSThomas Raoux
67276cf33daSThomas Raoux // Assume all the values coming from above are uniform.
67376cf33daSThomas Raoux if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
67476cf33daSThomas Raoux if (result.getType() != operand.get().getType())
67576cf33daSThomas Raoux continue;
67676cf33daSThomas Raoux valForwarded = operand.get();
67776cf33daSThomas Raoux resultIndex = operand.getOperandNumber();
67876cf33daSThomas Raoux break;
67976cf33daSThomas Raoux }
68076cf33daSThomas Raoux auto arg = operand.get().dyn_cast<BlockArgument>();
68176cf33daSThomas Raoux if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
68276cf33daSThomas Raoux continue;
68376cf33daSThomas Raoux Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
68476cf33daSThomas Raoux if (result.getType() != warpOperand.getType())
68576cf33daSThomas Raoux continue;
68676cf33daSThomas Raoux valForwarded = warpOperand;
68776cf33daSThomas Raoux resultIndex = operand.getOperandNumber();
68876cf33daSThomas Raoux break;
68976cf33daSThomas Raoux }
69076cf33daSThomas Raoux if (!valForwarded)
69176cf33daSThomas Raoux return failure();
69276cf33daSThomas Raoux warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
69376cf33daSThomas Raoux return success();
69476cf33daSThomas Raoux }
69576cf33daSThomas Raoux };
69676cf33daSThomas Raoux
69776cf33daSThomas Raoux struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
69876cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpBroadcast69976cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
70076cf33daSThomas Raoux PatternRewriter &rewriter) const override {
70176cf33daSThomas Raoux OpOperand *operand = getWarpResult(
70276cf33daSThomas Raoux warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
70376cf33daSThomas Raoux if (!operand)
70476cf33daSThomas Raoux return failure();
70576cf33daSThomas Raoux unsigned int operandNumber = operand->getOperandNumber();
70676cf33daSThomas Raoux auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
70776cf33daSThomas Raoux Location loc = broadcastOp.getLoc();
70876cf33daSThomas Raoux auto destVecType =
70976cf33daSThomas Raoux warpOp->getResultTypes()[operandNumber].cast<VectorType>();
710d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices;
71176cf33daSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
71276cf33daSThomas Raoux rewriter, warpOp, {broadcastOp.getSource()},
713d7d6443dSThomas Raoux {broadcastOp.getSource().getType()}, newRetIndices);
71476cf33daSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
71576cf33daSThomas Raoux Value broadcasted = rewriter.create<vector::BroadcastOp>(
716d7d6443dSThomas Raoux loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
71776cf33daSThomas Raoux newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
71876cf33daSThomas Raoux return success();
71976cf33daSThomas Raoux }
72076cf33daSThomas Raoux };
72176cf33daSThomas Raoux
722*f48ce52cSThomas Raoux /// Pattern to move out vector.extract of single element vector. Those don't
723*f48ce52cSThomas Raoux /// need to be distributed and can just be propagated outside of the region.
724*f48ce52cSThomas Raoux struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
725*f48ce52cSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpExtract726*f48ce52cSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
727*f48ce52cSThomas Raoux PatternRewriter &rewriter) const override {
728*f48ce52cSThomas Raoux OpOperand *operand = getWarpResult(
729*f48ce52cSThomas Raoux warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
730*f48ce52cSThomas Raoux if (!operand)
731*f48ce52cSThomas Raoux return failure();
732*f48ce52cSThomas Raoux unsigned int operandNumber = operand->getOperandNumber();
733*f48ce52cSThomas Raoux auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
734*f48ce52cSThomas Raoux if (extractOp.getVectorType().getNumElements() != 1)
735*f48ce52cSThomas Raoux return failure();
736*f48ce52cSThomas Raoux Location loc = extractOp.getLoc();
737*f48ce52cSThomas Raoux SmallVector<size_t> newRetIndices;
738*f48ce52cSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
739*f48ce52cSThomas Raoux rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
740*f48ce52cSThomas Raoux newRetIndices);
741*f48ce52cSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
742*f48ce52cSThomas Raoux Value newExtract = rewriter.create<vector::ExtractOp>(
743*f48ce52cSThomas Raoux loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
744*f48ce52cSThomas Raoux newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
745*f48ce52cSThomas Raoux return success();
746*f48ce52cSThomas Raoux }
747*f48ce52cSThomas Raoux };
748*f48ce52cSThomas Raoux
74976cf33daSThomas Raoux /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
75076cf33daSThomas Raoux /// the scf.ForOp is the last operation in the region so that it doesn't change
75176cf33daSThomas Raoux /// the order of execution. This creates a new scf.for region after the
75276cf33daSThomas Raoux /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
75376cf33daSThomas Raoux /// WarpExecuteOnLane0Op region. Example:
75476cf33daSThomas Raoux /// ```
75576cf33daSThomas Raoux /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
75676cf33daSThomas Raoux /// ...
75776cf33daSThomas Raoux /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
75876cf33daSThomas Raoux /// -> (vector<128xf32>) {
75976cf33daSThomas Raoux /// ...
76076cf33daSThomas Raoux /// scf.yield %r : vector<128xf32>
76176cf33daSThomas Raoux /// }
76276cf33daSThomas Raoux /// vector.yield %v1 : vector<128xf32>
76376cf33daSThomas Raoux /// }
76476cf33daSThomas Raoux /// ```
76576cf33daSThomas Raoux /// To:
76676cf33daSThomas Raoux /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
76776cf33daSThomas Raoux /// ...
76876cf33daSThomas Raoux /// vector.yield %v : vector<128xf32>
76976cf33daSThomas Raoux /// }
77076cf33daSThomas Raoux /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
77176cf33daSThomas Raoux /// -> (vector<4xf32>) {
77276cf33daSThomas Raoux /// %iw = vector.warp_execute_on_lane_0(%laneid)
77376cf33daSThomas Raoux /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
77476cf33daSThomas Raoux /// ^bb0(%arg: vector<128xf32>):
77576cf33daSThomas Raoux /// ...
77676cf33daSThomas Raoux /// vector.yield %ir : vector<128xf32>
77776cf33daSThomas Raoux /// }
77876cf33daSThomas Raoux /// scf.yield %iw : vector<4xf32>
77976cf33daSThomas Raoux /// }
78076cf33daSThomas Raoux /// ```
78176cf33daSThomas Raoux struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
78276cf33daSThomas Raoux using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpScfForOp78376cf33daSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
78476cf33daSThomas Raoux PatternRewriter &rewriter) const override {
78576cf33daSThomas Raoux auto yield = cast<vector::YieldOp>(
78676cf33daSThomas Raoux warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
78776cf33daSThomas Raoux // Only pick up forOp if it is the last op in the region.
78876cf33daSThomas Raoux Operation *lastNode = yield->getPrevNode();
78976cf33daSThomas Raoux auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
79076cf33daSThomas Raoux if (!forOp)
79176cf33daSThomas Raoux return failure();
79276cf33daSThomas Raoux SmallVector<Value> newOperands;
79376cf33daSThomas Raoux SmallVector<unsigned> resultIdx;
79476cf33daSThomas Raoux // Collect all the outputs coming from the forOp.
79576cf33daSThomas Raoux for (OpOperand &yieldOperand : yield->getOpOperands()) {
79676cf33daSThomas Raoux if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
79776cf33daSThomas Raoux continue;
79876cf33daSThomas Raoux auto forResult = yieldOperand.get().cast<OpResult>();
79976cf33daSThomas Raoux newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
80076cf33daSThomas Raoux yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
80176cf33daSThomas Raoux resultIdx.push_back(yieldOperand.getOperandNumber());
80276cf33daSThomas Raoux }
80376cf33daSThomas Raoux OpBuilder::InsertionGuard g(rewriter);
80476cf33daSThomas Raoux rewriter.setInsertionPointAfter(warpOp);
80576cf33daSThomas Raoux // Create a new for op outside the region with a WarpExecuteOnLane0Op region
80676cf33daSThomas Raoux // inside.
80776cf33daSThomas Raoux auto newForOp = rewriter.create<scf::ForOp>(
80876cf33daSThomas Raoux forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
80976cf33daSThomas Raoux forOp.getStep(), newOperands);
81076cf33daSThomas Raoux rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
81176cf33daSThomas Raoux auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
81276cf33daSThomas Raoux warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
81376cf33daSThomas Raoux warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
81476cf33daSThomas Raoux forOp.getResultTypes());
81576cf33daSThomas Raoux
81676cf33daSThomas Raoux SmallVector<Value> argMapping;
81776cf33daSThomas Raoux argMapping.push_back(newForOp.getInductionVar());
81876cf33daSThomas Raoux for (Value args : innerWarp.getBody()->getArguments()) {
81976cf33daSThomas Raoux argMapping.push_back(args);
82076cf33daSThomas Raoux }
82176cf33daSThomas Raoux SmallVector<Value> yieldOperands;
82276cf33daSThomas Raoux for (Value operand : forOp.getBody()->getTerminator()->getOperands())
82376cf33daSThomas Raoux yieldOperands.push_back(operand);
82476cf33daSThomas Raoux rewriter.eraseOp(forOp.getBody()->getTerminator());
82576cf33daSThomas Raoux rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
82676cf33daSThomas Raoux rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
82776cf33daSThomas Raoux rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
82876cf33daSThomas Raoux rewriter.setInsertionPointAfter(innerWarp);
829d343cdd5SThomas Raoux if (!innerWarp.getResults().empty())
83076cf33daSThomas Raoux rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
83176cf33daSThomas Raoux rewriter.eraseOp(forOp);
83276cf33daSThomas Raoux // Replace the warpOp result coming from the original ForOp.
83376cf33daSThomas Raoux for (const auto &res : llvm::enumerate(resultIdx)) {
83476cf33daSThomas Raoux warpOp.getResult(res.value())
83576cf33daSThomas Raoux .replaceAllUsesWith(newForOp.getResult(res.index()));
83676cf33daSThomas Raoux newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
83776cf33daSThomas Raoux }
83876cf33daSThomas Raoux return success();
83976cf33daSThomas Raoux }
84076cf33daSThomas Raoux };
84176cf33daSThomas Raoux
842087aba4fSThomas Raoux /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
8436834803cSThomas Raoux /// The vector is reduced in parallel. Currently limited to vector size matching
8446834803cSThomas Raoux /// the warpOp size. E.g.:
845087aba4fSThomas Raoux /// ```
8466834803cSThomas Raoux /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
847087aba4fSThomas Raoux /// %0 = "some_def"() : () -> (vector<32xf32>)
848087aba4fSThomas Raoux /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
849087aba4fSThomas Raoux /// vector_ext.yield %1 : f32
850087aba4fSThomas Raoux /// }
851087aba4fSThomas Raoux /// ```
852087aba4fSThomas Raoux /// is lowered to:
853087aba4fSThomas Raoux /// ```
8546834803cSThomas Raoux /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
855087aba4fSThomas Raoux /// %1 = "some_def"() : () -> (vector<32xf32>)
856087aba4fSThomas Raoux /// vector_ext.yield %1 : vector<32xf32>
857087aba4fSThomas Raoux /// }
858087aba4fSThomas Raoux /// %a = vector.extract %0[0] : vector<1xf32>
8596834803cSThomas Raoux /// %r = ("warp.reduction %a")
860087aba4fSThomas Raoux /// ```
8616834803cSThomas Raoux struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpReduction__anonbfa0aa500211::WarpOpReduction8626834803cSThomas Raoux WarpOpReduction(MLIRContext *context,
8636834803cSThomas Raoux DistributedReductionFn distributedReductionFn,
8646834803cSThomas Raoux PatternBenefit benefit = 1)
8656834803cSThomas Raoux : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
8666834803cSThomas Raoux distributedReductionFn(distributedReductionFn) {}
867087aba4fSThomas Raoux
matchAndRewrite__anonbfa0aa500211::WarpOpReduction868087aba4fSThomas Raoux LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
869087aba4fSThomas Raoux PatternRewriter &rewriter) const override {
870087aba4fSThomas Raoux OpOperand *yieldOperand = getWarpResult(
871087aba4fSThomas Raoux warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
872087aba4fSThomas Raoux if (!yieldOperand)
873087aba4fSThomas Raoux return failure();
874087aba4fSThomas Raoux
875087aba4fSThomas Raoux auto reductionOp =
876087aba4fSThomas Raoux cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
877087aba4fSThomas Raoux auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
878087aba4fSThomas Raoux // Only rank 1 vectors supported.
879087aba4fSThomas Raoux if (vectorType.getRank() != 1)
880087aba4fSThomas Raoux return rewriter.notifyMatchFailure(
881087aba4fSThomas Raoux warpOp, "Only rank 1 reductions can be distributed.");
882087aba4fSThomas Raoux // Only warp_size-sized vectors supported.
8830660f3c5SThomas Raoux if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
884087aba4fSThomas Raoux return rewriter.notifyMatchFailure(
885087aba4fSThomas Raoux warpOp, "Reduction vector dimension must match was size.");
886087aba4fSThomas Raoux // Only f32 and i32 element types are supported.
887087aba4fSThomas Raoux if (!reductionOp.getType().isF32() &&
888087aba4fSThomas Raoux !reductionOp.getType().isSignlessInteger(32))
889087aba4fSThomas Raoux return rewriter.notifyMatchFailure(
890087aba4fSThomas Raoux warpOp,
891087aba4fSThomas Raoux "Reduction distribution currently only supports 32bits types.");
892087aba4fSThomas Raoux
8930660f3c5SThomas Raoux int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
894087aba4fSThomas Raoux // Return vector that will be reduced from the WarpExecuteOnLane0Op.
895087aba4fSThomas Raoux unsigned operandIndex = yieldOperand->getOperandNumber();
896087aba4fSThomas Raoux SmallVector<Value> yieldValues = {reductionOp.getVector()};
8970660f3c5SThomas Raoux SmallVector<Type> retTypes = {
8980660f3c5SThomas Raoux VectorType::get({numElements}, reductionOp.getType())};
899ffa7384fSThomas Raoux if (reductionOp.getAcc()) {
900ffa7384fSThomas Raoux yieldValues.push_back(reductionOp.getAcc());
901ffa7384fSThomas Raoux retTypes.push_back(reductionOp.getAcc().getType());
902ffa7384fSThomas Raoux }
903d7d6443dSThomas Raoux SmallVector<size_t> newRetIndices;
904087aba4fSThomas Raoux WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
905d7d6443dSThomas Raoux rewriter, warpOp, yieldValues, retTypes, newRetIndices);
906087aba4fSThomas Raoux rewriter.setInsertionPointAfter(newWarpOp);
907087aba4fSThomas Raoux
908d7d6443dSThomas Raoux Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
9090660f3c5SThomas Raoux // First reduce on a single thread.
9100660f3c5SThomas Raoux Value perLaneReduction = rewriter.create<vector::ReductionOp>(
9110660f3c5SThomas Raoux reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
9120660f3c5SThomas Raoux // Then distribute across threads.
9130660f3c5SThomas Raoux Value fullReduce =
9140660f3c5SThomas Raoux distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
9156834803cSThomas Raoux reductionOp.getKind(), newWarpOp.getWarpSize());
916ffa7384fSThomas Raoux if (reductionOp.getAcc()) {
917ffa7384fSThomas Raoux fullReduce = vector::makeArithReduction(
918ffa7384fSThomas Raoux rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
919ffa7384fSThomas Raoux newWarpOp.getResult(newRetIndices[1]));
920ffa7384fSThomas Raoux }
9210660f3c5SThomas Raoux newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
922087aba4fSThomas Raoux return success();
923087aba4fSThomas Raoux }
9246834803cSThomas Raoux
9256834803cSThomas Raoux private:
9266834803cSThomas Raoux DistributedReductionFn distributedReductionFn;
927087aba4fSThomas Raoux };
928087aba4fSThomas Raoux
929d02f10d9SThomas Raoux } // namespace
930d02f10d9SThomas Raoux
populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet & patterns,const WarpExecuteOnLane0LoweringOptions & options)931d02f10d9SThomas Raoux void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
932d02f10d9SThomas Raoux RewritePatternSet &patterns,
933d02f10d9SThomas Raoux const WarpExecuteOnLane0LoweringOptions &options) {
934d02f10d9SThomas Raoux patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
935d02f10d9SThomas Raoux }
936ed0288f7SThomas Raoux
populateDistributeTransferWriteOpPatterns(RewritePatternSet & patterns,const DistributionMapFn & distributionMapFn)937ed0288f7SThomas Raoux void mlir::vector::populateDistributeTransferWriteOpPatterns(
93808d651d7SMehdi Amini RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
939ed0288f7SThomas Raoux patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
940ed0288f7SThomas Raoux }
941ed0288f7SThomas Raoux
populatePropagateWarpVectorDistributionPatterns(RewritePatternSet & patterns)94276cf33daSThomas Raoux void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
94376cf33daSThomas Raoux RewritePatternSet &patterns) {
94476cf33daSThomas Raoux patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
945*f48ce52cSThomas Raoux WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
946*f48ce52cSThomas Raoux WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
94776cf33daSThomas Raoux }
94876cf33daSThomas Raoux
populateDistributeReduction(RewritePatternSet & patterns,DistributedReductionFn distributedReductionFn)9496834803cSThomas Raoux void mlir::vector::populateDistributeReduction(
9506834803cSThomas Raoux RewritePatternSet &patterns,
9516834803cSThomas Raoux DistributedReductionFn distributedReductionFn) {
9526834803cSThomas Raoux patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
953087aba4fSThomas Raoux }
954087aba4fSThomas Raoux
moveScalarUniformCode(WarpExecuteOnLane0Op warpOp)955ed0288f7SThomas Raoux void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
956ed0288f7SThomas Raoux Block *body = warpOp.getBody();
957ed0288f7SThomas Raoux
958ed0288f7SThomas Raoux // Keep track of the ops we want to hoist.
959ed0288f7SThomas Raoux llvm::SmallSetVector<Operation *, 8> opsToMove;
960ed0288f7SThomas Raoux
961ed0288f7SThomas Raoux // Helper to check if a value is or will be defined outside of the region.
962ed0288f7SThomas Raoux auto isDefinedOutsideOfBody = [&](Value value) {
963ed0288f7SThomas Raoux auto *definingOp = value.getDefiningOp();
964ed0288f7SThomas Raoux return (definingOp && opsToMove.count(definingOp)) ||
965ed0288f7SThomas Raoux warpOp.isDefinedOutsideOfRegion(value);
966ed0288f7SThomas Raoux };
967ed0288f7SThomas Raoux
968ed0288f7SThomas Raoux // Do not use walk here, as we do not want to go into nested regions and hoist
969ed0288f7SThomas Raoux // operations from there.
970ed0288f7SThomas Raoux for (auto &op : body->without_terminator()) {
971ed0288f7SThomas Raoux bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
972ed0288f7SThomas Raoux return result.getType().isa<VectorType>();
973ed0288f7SThomas Raoux });
974ed0288f7SThomas Raoux if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
975ed0288f7SThomas Raoux opsToMove.insert(&op);
976ed0288f7SThomas Raoux }
977ed0288f7SThomas Raoux
978ed0288f7SThomas Raoux // Move all the ops marked as uniform outside of the region.
979ed0288f7SThomas Raoux for (Operation *op : opsToMove)
980ed0288f7SThomas Raoux op->moveBefore(warpOp);
981ed0288f7SThomas Raoux }
982