//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8d02f10d9SThomas Raoux 
9ed0288f7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
10d02f10d9SThomas Raoux #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11d02f10d9SThomas Raoux #include "mlir/Dialect/MemRef/IR/MemRef.h"
128b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
13d02f10d9SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14087aba4fSThomas Raoux #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1576cf33daSThomas Raoux #include "mlir/IR/BlockAndValueMapping.h"
16ed0288f7SThomas Raoux #include "mlir/Transforms/SideEffectUtils.h"
17d7d6443dSThomas Raoux #include "llvm/ADT/SetVector.h"
1808d651d7SMehdi Amini #include <utility>
1908d651d7SMehdi Amini 
20d02f10d9SThomas Raoux using namespace mlir;
21d02f10d9SThomas Raoux using namespace mlir::vector;
22d02f10d9SThomas Raoux 
23d02f10d9SThomas Raoux static LogicalResult
rewriteWarpOpToScfFor(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,const WarpExecuteOnLane0LoweringOptions & options)24d02f10d9SThomas Raoux rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
25d02f10d9SThomas Raoux                       const WarpExecuteOnLane0LoweringOptions &options) {
26d02f10d9SThomas Raoux   assert(warpOp.getBodyRegion().hasOneBlock() &&
27d02f10d9SThomas Raoux          "expected WarpOp with single block");
28d02f10d9SThomas Raoux   Block *warpOpBody = &warpOp.getBodyRegion().front();
29d02f10d9SThomas Raoux   Location loc = warpOp.getLoc();
30d02f10d9SThomas Raoux 
31d02f10d9SThomas Raoux   // Passed all checks. Start rewriting.
32d02f10d9SThomas Raoux   OpBuilder::InsertionGuard g(rewriter);
33d02f10d9SThomas Raoux   rewriter.setInsertionPoint(warpOp);
34d02f10d9SThomas Raoux 
35d02f10d9SThomas Raoux   // Create scf.if op.
36d02f10d9SThomas Raoux   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
37d02f10d9SThomas Raoux   Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
38d02f10d9SThomas Raoux                                                  warpOp.getLaneid(), c0);
39d02f10d9SThomas Raoux   auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
40d02f10d9SThomas Raoux                                          /*withElseRegion=*/false);
41d02f10d9SThomas Raoux   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
42d02f10d9SThomas Raoux 
43d02f10d9SThomas Raoux   // Store vectors that are defined outside of warpOp into the scratch pad
44d02f10d9SThomas Raoux   // buffer.
45d02f10d9SThomas Raoux   SmallVector<Value> bbArgReplacements;
46d02f10d9SThomas Raoux   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
47d02f10d9SThomas Raoux     Value val = it.value();
48d02f10d9SThomas Raoux     Value bbArg = warpOpBody->getArgument(it.index());
49d02f10d9SThomas Raoux 
50d02f10d9SThomas Raoux     rewriter.setInsertionPoint(ifOp);
51f6c79c6aSNicolas Vasilache     Value buffer =
52f6c79c6aSNicolas Vasilache         options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
53d02f10d9SThomas Raoux 
54d02f10d9SThomas Raoux     // Store arg vector into buffer.
55d02f10d9SThomas Raoux     rewriter.setInsertionPoint(ifOp);
56d02f10d9SThomas Raoux     auto vectorType = val.getType().cast<VectorType>();
57d02f10d9SThomas Raoux     int64_t storeSize = vectorType.getShape()[0];
58d02f10d9SThomas Raoux     Value storeOffset = rewriter.create<arith::MulIOp>(
59d02f10d9SThomas Raoux         loc, warpOp.getLaneid(),
60d02f10d9SThomas Raoux         rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
61d02f10d9SThomas Raoux     rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
62d02f10d9SThomas Raoux 
63d02f10d9SThomas Raoux     // Load bbArg vector from buffer.
64d02f10d9SThomas Raoux     rewriter.setInsertionPointToStart(ifOp.thenBlock());
65d02f10d9SThomas Raoux     auto bbArgType = bbArg.getType().cast<VectorType>();
66d02f10d9SThomas Raoux     Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
67d02f10d9SThomas Raoux     bbArgReplacements.push_back(loadOp);
68d02f10d9SThomas Raoux   }
69d02f10d9SThomas Raoux 
70d02f10d9SThomas Raoux   // Insert sync after all the stores and before all the loads.
71d02f10d9SThomas Raoux   if (!warpOp.getArgs().empty()) {
72d02f10d9SThomas Raoux     rewriter.setInsertionPoint(ifOp);
73f6c79c6aSNicolas Vasilache     options.warpSyncronizationFn(loc, rewriter, warpOp);
74d02f10d9SThomas Raoux   }
75d02f10d9SThomas Raoux 
76d02f10d9SThomas Raoux   // Move body of warpOp to ifOp.
77d02f10d9SThomas Raoux   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
78d02f10d9SThomas Raoux 
79d02f10d9SThomas Raoux   // Rewrite terminator and compute replacements of WarpOp results.
80d02f10d9SThomas Raoux   SmallVector<Value> replacements;
81d02f10d9SThomas Raoux   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
82d02f10d9SThomas Raoux   Location yieldLoc = yieldOp.getLoc();
83d02f10d9SThomas Raoux   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
84d02f10d9SThomas Raoux     Value val = it.value();
85d02f10d9SThomas Raoux     Type resultType = warpOp->getResultTypes()[it.index()];
86d02f10d9SThomas Raoux     rewriter.setInsertionPoint(ifOp);
87f6c79c6aSNicolas Vasilache     Value buffer =
88f6c79c6aSNicolas Vasilache         options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
89d02f10d9SThomas Raoux 
90d02f10d9SThomas Raoux     // Store yielded value into buffer.
91d02f10d9SThomas Raoux     rewriter.setInsertionPoint(yieldOp);
92d02f10d9SThomas Raoux     if (val.getType().isa<VectorType>())
93d02f10d9SThomas Raoux       rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
94d02f10d9SThomas Raoux     else
95d02f10d9SThomas Raoux       rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
96d02f10d9SThomas Raoux 
97d02f10d9SThomas Raoux     // Load value from buffer (after warpOp).
98d02f10d9SThomas Raoux     rewriter.setInsertionPointAfter(ifOp);
99d02f10d9SThomas Raoux     if (resultType == val.getType()) {
100d02f10d9SThomas Raoux       // Result type and yielded value type are the same. This is a broadcast.
101d02f10d9SThomas Raoux       // E.g.:
102ed0288f7SThomas Raoux       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
103ed0288f7SThomas Raoux       //   vector.yield %cst : f32
104d02f10d9SThomas Raoux       // }
105d02f10d9SThomas Raoux       // Both types are f32. The constant %cst is broadcasted to all lanes.
106d02f10d9SThomas Raoux       // This is described in more detail in the documentation of the op.
107d02f10d9SThomas Raoux       Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
108d02f10d9SThomas Raoux       replacements.push_back(loadOp);
109d02f10d9SThomas Raoux     } else {
110d02f10d9SThomas Raoux       auto loadedVectorType = resultType.cast<VectorType>();
111d02f10d9SThomas Raoux       int64_t loadSize = loadedVectorType.getShape()[0];
112d02f10d9SThomas Raoux 
113d02f10d9SThomas Raoux       // loadOffset = laneid * loadSize
114d02f10d9SThomas Raoux       Value loadOffset = rewriter.create<arith::MulIOp>(
115d02f10d9SThomas Raoux           loc, warpOp.getLaneid(),
116d02f10d9SThomas Raoux           rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
117d02f10d9SThomas Raoux       Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
118d02f10d9SThomas Raoux                                                      buffer, loadOffset);
119d02f10d9SThomas Raoux       replacements.push_back(loadOp);
120d02f10d9SThomas Raoux     }
121d02f10d9SThomas Raoux   }
122d02f10d9SThomas Raoux 
123d02f10d9SThomas Raoux   // Insert sync after all the stores and before all the loads.
124d02f10d9SThomas Raoux   if (!yieldOp.operands().empty()) {
125d02f10d9SThomas Raoux     rewriter.setInsertionPointAfter(ifOp);
126f6c79c6aSNicolas Vasilache     options.warpSyncronizationFn(loc, rewriter, warpOp);
127d02f10d9SThomas Raoux   }
128d02f10d9SThomas Raoux 
129d02f10d9SThomas Raoux   // Delete terminator and add empty scf.yield.
130d02f10d9SThomas Raoux   rewriter.eraseOp(yieldOp);
131d02f10d9SThomas Raoux   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
132d02f10d9SThomas Raoux   rewriter.create<scf::YieldOp>(yieldLoc);
133d02f10d9SThomas Raoux 
134d02f10d9SThomas Raoux   // Compute replacements for WarpOp results.
135d02f10d9SThomas Raoux   rewriter.replaceOp(warpOp, replacements);
136d02f10d9SThomas Raoux 
137d02f10d9SThomas Raoux   return success();
138d02f10d9SThomas Raoux }
139d02f10d9SThomas Raoux 
140ed0288f7SThomas Raoux /// Helper to create a new WarpExecuteOnLane0Op with different signature.
moveRegionToNewWarpOpAndReplaceReturns(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,ValueRange newYieldedValues,TypeRange newReturnTypes)141ed0288f7SThomas Raoux static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
142ed0288f7SThomas Raoux     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
143ed0288f7SThomas Raoux     ValueRange newYieldedValues, TypeRange newReturnTypes) {
144ed0288f7SThomas Raoux   // Create a new op before the existing one, with the extra operands.
145ed0288f7SThomas Raoux   OpBuilder::InsertionGuard g(rewriter);
146ed0288f7SThomas Raoux   rewriter.setInsertionPoint(warpOp);
147ed0288f7SThomas Raoux   auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
148ed0288f7SThomas Raoux       warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
149ed0288f7SThomas Raoux       warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
150ed0288f7SThomas Raoux 
151ed0288f7SThomas Raoux   Region &opBody = warpOp.getBodyRegion();
152ed0288f7SThomas Raoux   Region &newOpBody = newWarpOp.getBodyRegion();
153f6c79c6aSNicolas Vasilache   Block &newOpFirstBlock = newOpBody.front();
154ed0288f7SThomas Raoux   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
155f6c79c6aSNicolas Vasilache   rewriter.eraseBlock(&newOpFirstBlock);
156f6c79c6aSNicolas Vasilache   assert(newWarpOp.getWarpRegion().hasOneBlock() &&
157f6c79c6aSNicolas Vasilache          "expected WarpOp with single block");
158f6c79c6aSNicolas Vasilache 
159ed0288f7SThomas Raoux   auto yield =
160ed0288f7SThomas Raoux       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
161ed0288f7SThomas Raoux 
162ed0288f7SThomas Raoux   rewriter.updateRootInPlace(
163ed0288f7SThomas Raoux       yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
164ed0288f7SThomas Raoux   return newWarpOp;
165ed0288f7SThomas Raoux }
166ed0288f7SThomas Raoux 
167ed0288f7SThomas Raoux /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
168d7d6443dSThomas Raoux /// `indices` return the index of each new output.
moveRegionToNewWarpOpAndAppendReturns(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,ValueRange newYieldedValues,TypeRange newReturnTypes,llvm::SmallVector<size_t> & indices)169ed0288f7SThomas Raoux static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
170ed0288f7SThomas Raoux     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
171d7d6443dSThomas Raoux     ValueRange newYieldedValues, TypeRange newReturnTypes,
172d7d6443dSThomas Raoux     llvm::SmallVector<size_t> &indices) {
173ed0288f7SThomas Raoux   SmallVector<Type> types(warpOp.getResultTypes().begin(),
174ed0288f7SThomas Raoux                           warpOp.getResultTypes().end());
175ed0288f7SThomas Raoux   auto yield = cast<vector::YieldOp>(
176ed0288f7SThomas Raoux       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
177d7d6443dSThomas Raoux   llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
178ed0288f7SThomas Raoux                                               yield.getOperands().end());
179d7d6443dSThomas Raoux   for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
180d7d6443dSThomas Raoux     if (yieldValues.insert(std::get<0>(newRet))) {
181d7d6443dSThomas Raoux       types.push_back(std::get<1>(newRet));
182d7d6443dSThomas Raoux       indices.push_back(yieldValues.size() - 1);
183d7d6443dSThomas Raoux     } else {
184d7d6443dSThomas Raoux       // If the value already exit the region don't create a new output.
185d7d6443dSThomas Raoux       for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
186d7d6443dSThomas Raoux         if (yieldOperand.value() == std::get<0>(newRet)) {
187d7d6443dSThomas Raoux           indices.push_back(yieldOperand.index());
188d7d6443dSThomas Raoux           break;
189d7d6443dSThomas Raoux         }
190d7d6443dSThomas Raoux       }
191d7d6443dSThomas Raoux     }
192d7d6443dSThomas Raoux   }
193d7d6443dSThomas Raoux   yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
194ed0288f7SThomas Raoux   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
195d7d6443dSThomas Raoux       rewriter, warpOp, yieldValues.getArrayRef(), types);
196ed0288f7SThomas Raoux   rewriter.replaceOp(warpOp,
197ed0288f7SThomas Raoux                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
198ed0288f7SThomas Raoux   return newWarpOp;
199ed0288f7SThomas Raoux }
200ed0288f7SThomas Raoux 
201ed0288f7SThomas Raoux /// Helper to know if an op can be hoisted out of the region.
canBeHoisted(Operation * op,function_ref<bool (Value)> definedOutside)202ed0288f7SThomas Raoux static bool canBeHoisted(Operation *op,
203ed0288f7SThomas Raoux                          function_ref<bool(Value)> definedOutside) {
204ed0288f7SThomas Raoux   return llvm::all_of(op->getOperands(), definedOutside) &&
205ed0288f7SThomas Raoux          isSideEffectFree(op) && op->getNumRegions() == 0;
206ed0288f7SThomas Raoux }
207ed0288f7SThomas Raoux 
20876cf33daSThomas Raoux /// Return a value yielded by `warpOp` which statifies the filter lamdba
20976cf33daSThomas Raoux /// condition and is not dead.
getWarpResult(WarpExecuteOnLane0Op warpOp,std::function<bool (Operation *)> fn)21076cf33daSThomas Raoux static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
21176cf33daSThomas Raoux                                 std::function<bool(Operation *)> fn) {
21276cf33daSThomas Raoux   auto yield = cast<vector::YieldOp>(
21376cf33daSThomas Raoux       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
21476cf33daSThomas Raoux   for (OpOperand &yieldOperand : yield->getOpOperands()) {
21576cf33daSThomas Raoux     Value yieldValues = yieldOperand.get();
21676cf33daSThomas Raoux     Operation *definedOp = yieldValues.getDefiningOp();
21776cf33daSThomas Raoux     if (definedOp && fn(definedOp)) {
21876cf33daSThomas Raoux       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
21976cf33daSThomas Raoux         return &yieldOperand;
22076cf33daSThomas Raoux     }
22176cf33daSThomas Raoux   }
22276cf33daSThomas Raoux   return {};
22376cf33daSThomas Raoux }
22476cf33daSThomas Raoux 
22576cf33daSThomas Raoux // Clones `op` into a new operation that takes `operands` and returns
22676cf33daSThomas Raoux // `resultTypes`.
cloneOpWithOperandsAndTypes(RewriterBase & rewriter,Location loc,Operation * op,ArrayRef<Value> operands,ArrayRef<Type> resultTypes)22776cf33daSThomas Raoux static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
22876cf33daSThomas Raoux                                               Location loc, Operation *op,
22976cf33daSThomas Raoux                                               ArrayRef<Value> operands,
23076cf33daSThomas Raoux                                               ArrayRef<Type> resultTypes) {
23176cf33daSThomas Raoux   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
23276cf33daSThomas Raoux                      op->getAttrs());
23376cf33daSThomas Raoux   return rewriter.create(res);
23476cf33daSThomas Raoux }
23576cf33daSThomas Raoux 
23676cf33daSThomas Raoux /// Currently the distribution map is implicit based on the vector shape. In the
23776cf33daSThomas Raoux /// future it will be part of the op.
23876cf33daSThomas Raoux /// Example:
23976cf33daSThomas Raoux /// ```
24076cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
24176cf33daSThomas Raoux ///   ...
24276cf33daSThomas Raoux ///   vector.yield %3 : vector<32x16x64xf32>
24376cf33daSThomas Raoux /// }
24476cf33daSThomas Raoux /// ```
24576cf33daSThomas Raoux /// Would have an implicit map of:
24676cf33daSThomas Raoux /// `(d0, d1, d2) -> (d0, d2)`
calculateImplicitMap(Value yield,Value ret)24776cf33daSThomas Raoux static AffineMap calculateImplicitMap(Value yield, Value ret) {
24876cf33daSThomas Raoux   auto srcType = yield.getType().cast<VectorType>();
24976cf33daSThomas Raoux   auto dstType = ret.getType().cast<VectorType>();
25076cf33daSThomas Raoux   SmallVector<AffineExpr> perm;
25176cf33daSThomas Raoux   // Check which dimensions of the yield value are different than the dimensions
25276cf33daSThomas Raoux   // of the result to know the distributed dimensions. Then associate each
25376cf33daSThomas Raoux   // distributed dimension to an ID in order.
25476cf33daSThomas Raoux   for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
25576cf33daSThomas Raoux     if (srcType.getDimSize(i) != dstType.getDimSize(i))
25676cf33daSThomas Raoux       perm.push_back(getAffineDimExpr(i, yield.getContext()));
25776cf33daSThomas Raoux   }
25876cf33daSThomas Raoux   auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
25976cf33daSThomas Raoux   return map;
26076cf33daSThomas Raoux }
26176cf33daSThomas Raoux 
262d02f10d9SThomas Raoux namespace {
263d02f10d9SThomas Raoux 
264d02f10d9SThomas Raoux struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpToScfForPattern__anonbfa0aa500211::WarpOpToScfForPattern265d02f10d9SThomas Raoux   WarpOpToScfForPattern(MLIRContext *context,
266d02f10d9SThomas Raoux                         const WarpExecuteOnLane0LoweringOptions &options,
267d02f10d9SThomas Raoux                         PatternBenefit benefit = 1)
268d02f10d9SThomas Raoux       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
269d02f10d9SThomas Raoux         options(options) {}
270d02f10d9SThomas Raoux 
matchAndRewrite__anonbfa0aa500211::WarpOpToScfForPattern271d02f10d9SThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
272d02f10d9SThomas Raoux                                 PatternRewriter &rewriter) const override {
273d02f10d9SThomas Raoux     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
274d02f10d9SThomas Raoux   }
275d02f10d9SThomas Raoux 
276d02f10d9SThomas Raoux private:
277d02f10d9SThomas Raoux   const WarpExecuteOnLane0LoweringOptions &options;
278d02f10d9SThomas Raoux };
279d02f10d9SThomas Raoux 
2806a57d8fbSNicolas Vasilache /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
2816a57d8fbSNicolas Vasilache /// op with the proper return type.
2826a57d8fbSNicolas Vasilache /// The new write op is updated to write the result of the new warp execute op.
2836a57d8fbSNicolas Vasilache /// The old `writeOp` is deleted.
cloneWriteOp(RewriterBase & rewriter,WarpExecuteOnLane0Op warpOp,vector::TransferWriteOp writeOp,VectorType targetType)2846a57d8fbSNicolas Vasilache static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
2856a57d8fbSNicolas Vasilache                                             WarpExecuteOnLane0Op warpOp,
2866a57d8fbSNicolas Vasilache                                             vector::TransferWriteOp writeOp,
2876a57d8fbSNicolas Vasilache                                             VectorType targetType) {
2886a57d8fbSNicolas Vasilache   assert(writeOp->getParentOp() == warpOp &&
2896a57d8fbSNicolas Vasilache          "write must be nested immediately under warp");
2906a57d8fbSNicolas Vasilache   OpBuilder::InsertionGuard g(rewriter);
291d7d6443dSThomas Raoux   SmallVector<size_t> newRetIndices;
2926a57d8fbSNicolas Vasilache   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2936a57d8fbSNicolas Vasilache       rewriter, warpOp, ValueRange{{writeOp.getVector()}},
294d7d6443dSThomas Raoux       TypeRange{targetType}, newRetIndices);
2956a57d8fbSNicolas Vasilache   rewriter.setInsertionPointAfter(newWarpOp);
2966a57d8fbSNicolas Vasilache   auto newWriteOp =
2976a57d8fbSNicolas Vasilache       cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
2986a57d8fbSNicolas Vasilache   rewriter.eraseOp(writeOp);
299d7d6443dSThomas Raoux   newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
3006a57d8fbSNicolas Vasilache   return newWriteOp;
3016a57d8fbSNicolas Vasilache }
3026a57d8fbSNicolas Vasilache 
303ed0288f7SThomas Raoux /// Distribute transfer_write ops based on the affine map returned by
304ed0288f7SThomas Raoux /// `distributionMapFn`.
305ed0288f7SThomas Raoux /// Example:
306ed0288f7SThomas Raoux /// ```
307ed0288f7SThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%id){
308ed0288f7SThomas Raoux ///   ...
309ed0288f7SThomas Raoux ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
310ed0288f7SThomas Raoux ///   vector.yield
311ed0288f7SThomas Raoux /// }
312ed0288f7SThomas Raoux /// ```
313ed0288f7SThomas Raoux /// To
314ed0288f7SThomas Raoux /// ```
315ed0288f7SThomas Raoux /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
316ed0288f7SThomas Raoux ///   ...
317ed0288f7SThomas Raoux ///   vector.yield %v : vector<32xf32>
318ed0288f7SThomas Raoux /// }
319ed0288f7SThomas Raoux /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
320ed0288f7SThomas Raoux struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
WarpOpTransferWrite__anonbfa0aa500211::WarpOpTransferWrite321ed0288f7SThomas Raoux   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
322ed0288f7SThomas Raoux                       PatternBenefit b = 1)
323ed0288f7SThomas Raoux       : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
32408d651d7SMehdi Amini         distributionMapFn(std::move(fn)) {}
325ed0288f7SThomas Raoux 
326ed0288f7SThomas Raoux   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
327ed0288f7SThomas Raoux   /// are multiples of the distribution ratio are supported at the moment.
tryDistributeOp__anonbfa0aa500211::WarpOpTransferWrite328ed0288f7SThomas Raoux   LogicalResult tryDistributeOp(RewriterBase &rewriter,
329ed0288f7SThomas Raoux                                 vector::TransferWriteOp writeOp,
330ed0288f7SThomas Raoux                                 WarpExecuteOnLane0Op warpOp) const {
3316a57d8fbSNicolas Vasilache     VectorType writtenVectorType = writeOp.getVectorType();
3326a57d8fbSNicolas Vasilache 
3336a57d8fbSNicolas Vasilache     // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
3346a57d8fbSNicolas Vasilache     // to separate it from the rest.
3356a57d8fbSNicolas Vasilache     if (writtenVectorType.getRank() == 0)
3366a57d8fbSNicolas Vasilache       return failure();
3376a57d8fbSNicolas Vasilache 
3386a57d8fbSNicolas Vasilache     // 2. Compute the distribution map.
339ed0288f7SThomas Raoux     AffineMap map = distributionMapFn(writeOp);
3406a57d8fbSNicolas Vasilache     if (map.getNumResults() != 1)
3416a57d8fbSNicolas Vasilache       return writeOp->emitError("multi-dim distribution not implemented yet");
3426a57d8fbSNicolas Vasilache 
3436a57d8fbSNicolas Vasilache     // 3. Compute the targetType using the distribution map.
3446a57d8fbSNicolas Vasilache     SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
3456a57d8fbSNicolas Vasilache                                      writtenVectorType.getShape().end());
346ed0288f7SThomas Raoux     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
347ed0288f7SThomas Raoux       unsigned position = map.getDimPosition(i);
348ed0288f7SThomas Raoux       if (targetShape[position] % warpOp.getWarpSize() != 0)
349ed0288f7SThomas Raoux         return failure();
350ed0288f7SThomas Raoux       targetShape[position] = targetShape[position] / warpOp.getWarpSize();
351ed0288f7SThomas Raoux     }
352ed0288f7SThomas Raoux     VectorType targetType =
3536a57d8fbSNicolas Vasilache         VectorType::get(targetShape, writtenVectorType.getElementType());
354ed0288f7SThomas Raoux 
3556a57d8fbSNicolas Vasilache     // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
3566a57d8fbSNicolas Vasilache     // the rest.
3576a57d8fbSNicolas Vasilache     vector::TransferWriteOp newWriteOp =
3586a57d8fbSNicolas Vasilache         cloneWriteOp(rewriter, warpOp, writeOp, targetType);
359ed0288f7SThomas Raoux 
3606a57d8fbSNicolas Vasilache     // 5. Reindex the write using the distribution map.
3616a57d8fbSNicolas Vasilache     auto newWarpOp =
3626a57d8fbSNicolas Vasilache         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
363ed0288f7SThomas Raoux     rewriter.setInsertionPoint(newWriteOp);
364ed0288f7SThomas Raoux     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
365ed0288f7SThomas Raoux     Location loc = newWriteOp.getLoc();
366ed0288f7SThomas Raoux     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
367ed0288f7SThomas Raoux                                newWriteOp.getIndices().end());
368ed0288f7SThomas Raoux     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
369ed0288f7SThomas Raoux       AffineExpr d0, d1;
370ed0288f7SThomas Raoux       bindDims(newWarpOp.getContext(), d0, d1);
371ed0288f7SThomas Raoux       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
372ed0288f7SThomas Raoux       if (!indexExpr)
373ed0288f7SThomas Raoux         continue;
374ed0288f7SThomas Raoux       unsigned indexPos = indexExpr.getPosition();
375ed0288f7SThomas Raoux       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
3766a57d8fbSNicolas Vasilache       auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
377ed0288f7SThomas Raoux       indices[indexPos] =
378ed0288f7SThomas Raoux           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
379ed0288f7SThomas Raoux                                   {indices[indexPos], newWarpOp.getLaneid()});
380ed0288f7SThomas Raoux     }
381ed0288f7SThomas Raoux     newWriteOp.getIndicesMutable().assign(indices);
382ed0288f7SThomas Raoux 
383ed0288f7SThomas Raoux     return success();
384ed0288f7SThomas Raoux   }
385ed0288f7SThomas Raoux 
386ed0288f7SThomas Raoux   /// Extract TransferWriteOps of vector<1x> into a separate warp op.
tryExtractOp__anonbfa0aa500211::WarpOpTransferWrite387ed0288f7SThomas Raoux   LogicalResult tryExtractOp(RewriterBase &rewriter,
388ed0288f7SThomas Raoux                              vector::TransferWriteOp writeOp,
389ed0288f7SThomas Raoux                              WarpExecuteOnLane0Op warpOp) const {
390ed0288f7SThomas Raoux     Location loc = writeOp.getLoc();
391ed0288f7SThomas Raoux     VectorType vecType = writeOp.getVectorType();
392ed0288f7SThomas Raoux 
3937eba5cdfSThomas Raoux     // Only sink out vector of 1 element for now to not serialize large vector
3947eba5cdfSThomas Raoux     // store. This can later be controlled by user.
3957eba5cdfSThomas Raoux     if (vecType.getNumElements() != 1)
396ed0288f7SThomas Raoux       return failure();
397ed0288f7SThomas Raoux 
398ed0288f7SThomas Raoux     // Do not process warp ops that contain only TransferWriteOps.
399ed0288f7SThomas Raoux     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
400ed0288f7SThomas Raoux           return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
401ed0288f7SThomas Raoux         }))
402ed0288f7SThomas Raoux       return failure();
403ed0288f7SThomas Raoux 
404ed0288f7SThomas Raoux     SmallVector<Value> yieldValues = {writeOp.getVector()};
405ed0288f7SThomas Raoux     SmallVector<Type> retTypes = {vecType};
406d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
407ed0288f7SThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
408d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
409ed0288f7SThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
410ed0288f7SThomas Raoux 
411ed0288f7SThomas Raoux     // Create a second warp op that contains only writeOp.
412ed0288f7SThomas Raoux     auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
413ed0288f7SThomas Raoux         loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
414ed0288f7SThomas Raoux     Block &body = secondWarpOp.getBodyRegion().front();
415ed0288f7SThomas Raoux     rewriter.setInsertionPointToStart(&body);
416ed0288f7SThomas Raoux     auto newWriteOp =
417ed0288f7SThomas Raoux         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
418d7d6443dSThomas Raoux     newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
419ed0288f7SThomas Raoux     rewriter.eraseOp(writeOp);
420ed0288f7SThomas Raoux     rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
421ed0288f7SThomas Raoux     return success();
422ed0288f7SThomas Raoux   }
423ed0288f7SThomas Raoux 
matchAndRewrite__anonbfa0aa500211::WarpOpTransferWrite424ed0288f7SThomas Raoux   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
425ed0288f7SThomas Raoux                                 PatternRewriter &rewriter) const override {
426ed0288f7SThomas Raoux     // Ops with mask not supported yet.
427ed0288f7SThomas Raoux     if (writeOp.getMask())
428ed0288f7SThomas Raoux       return failure();
429ed0288f7SThomas Raoux 
430ed0288f7SThomas Raoux     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
431ed0288f7SThomas Raoux     if (!warpOp)
432ed0288f7SThomas Raoux       return failure();
433ed0288f7SThomas Raoux 
434ed0288f7SThomas Raoux     // There must be no op with a side effect after writeOp.
435ed0288f7SThomas Raoux     Operation *nextOp = writeOp.getOperation();
436ed0288f7SThomas Raoux     while ((nextOp = nextOp->getNextNode()))
437ed0288f7SThomas Raoux       if (!isSideEffectFree(nextOp))
438ed0288f7SThomas Raoux         return failure();
439ed0288f7SThomas Raoux 
440ed0288f7SThomas Raoux     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
441ed0288f7SThomas Raoux           return writeOp.getVector() == value ||
442ed0288f7SThomas Raoux                  warpOp.isDefinedOutsideOfRegion(value);
443ed0288f7SThomas Raoux         }))
444ed0288f7SThomas Raoux       return failure();
445ed0288f7SThomas Raoux 
446ed0288f7SThomas Raoux     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
447ed0288f7SThomas Raoux       return success();
448ed0288f7SThomas Raoux 
449ed0288f7SThomas Raoux     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
450ed0288f7SThomas Raoux       return success();
451ed0288f7SThomas Raoux 
452ed0288f7SThomas Raoux     return failure();
453ed0288f7SThomas Raoux   }
454ed0288f7SThomas Raoux 
455ed0288f7SThomas Raoux private:
456ed0288f7SThomas Raoux   DistributionMapFn distributionMapFn;
457ed0288f7SThomas Raoux };
458ed0288f7SThomas Raoux 
45976cf33daSThomas Raoux /// Sink out elementwise op feeding into a warp op yield.
46076cf33daSThomas Raoux /// ```
46176cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
46276cf33daSThomas Raoux ///   ...
46376cf33daSThomas Raoux ///   %3 = arith.addf %1, %2 : vector<32xf32>
46476cf33daSThomas Raoux ///   vector.yield %3 : vector<32xf32>
46576cf33daSThomas Raoux /// }
46676cf33daSThomas Raoux /// ```
46776cf33daSThomas Raoux /// To
46876cf33daSThomas Raoux /// ```
46976cf33daSThomas Raoux /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
47076cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
47176cf33daSThomas Raoux ///   ...
47276cf33daSThomas Raoux ///   %4 = arith.addf %2, %3 : vector<32xf32>
47376cf33daSThomas Raoux ///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
47476cf33daSThomas Raoux ///   vector<32xf32>
47576cf33daSThomas Raoux /// }
47676cf33daSThomas Raoux /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
47776cf33daSThomas Raoux struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
47876cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpElementwise47976cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
48076cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
48176cf33daSThomas Raoux     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
48276cf33daSThomas Raoux       return OpTrait::hasElementwiseMappableTraits(op);
48376cf33daSThomas Raoux     });
48476cf33daSThomas Raoux     if (!yieldOperand)
48576cf33daSThomas Raoux       return failure();
48676cf33daSThomas Raoux     Operation *elementWise = yieldOperand->get().getDefiningOp();
48776cf33daSThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
48876cf33daSThomas Raoux     Value distributedVal = warpOp.getResult(operandIndex);
48976cf33daSThomas Raoux     SmallVector<Value> yieldValues;
49076cf33daSThomas Raoux     SmallVector<Type> retTypes;
49176cf33daSThomas Raoux     Location loc = warpOp.getLoc();
49276cf33daSThomas Raoux     for (OpOperand &operand : elementWise->getOpOperands()) {
49376cf33daSThomas Raoux       Type targetType;
49476cf33daSThomas Raoux       if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
49576cf33daSThomas Raoux         // If the result type is a vector, the operands must also be vectors.
49676cf33daSThomas Raoux         auto operandType = operand.get().getType().cast<VectorType>();
49776cf33daSThomas Raoux         targetType =
49876cf33daSThomas Raoux             VectorType::get(vecType.getShape(), operandType.getElementType());
49976cf33daSThomas Raoux       } else {
50076cf33daSThomas Raoux         auto operandType = operand.get().getType();
50176cf33daSThomas Raoux         assert(!operandType.isa<VectorType>() &&
50276cf33daSThomas Raoux                "unexpected yield of vector from op with scalar result type");
50376cf33daSThomas Raoux         targetType = operandType;
50476cf33daSThomas Raoux       }
50576cf33daSThomas Raoux       retTypes.push_back(targetType);
50676cf33daSThomas Raoux       yieldValues.push_back(operand.get());
50776cf33daSThomas Raoux     }
508d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
50976cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
510d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
51176cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
51276cf33daSThomas Raoux     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
51376cf33daSThomas Raoux                                    elementWise->getOperands().end());
51476cf33daSThomas Raoux     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
515d7d6443dSThomas Raoux       newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
51676cf33daSThomas Raoux     }
51776cf33daSThomas Raoux     OpBuilder::InsertionGuard g(rewriter);
51876cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
51976cf33daSThomas Raoux     Operation *newOp = cloneOpWithOperandsAndTypes(
52076cf33daSThomas Raoux         rewriter, loc, elementWise, newOperands,
52176cf33daSThomas Raoux         {newWarpOp.getResult(operandIndex).getType()});
52276cf33daSThomas Raoux     newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
52376cf33daSThomas Raoux     return success();
52476cf33daSThomas Raoux   }
52576cf33daSThomas Raoux };
52676cf33daSThomas Raoux 
5270af26805SThomas Raoux /// Sink out splat constant op feeding into a warp op yield.
5280af26805SThomas Raoux /// ```
5290af26805SThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
5300af26805SThomas Raoux ///   ...
5310af26805SThomas Raoux ///   %cst = arith.constant dense<2.0> : vector<32xf32>
5320af26805SThomas Raoux ///   vector.yield %cst : vector<32xf32>
5330af26805SThomas Raoux /// }
5340af26805SThomas Raoux /// ```
5350af26805SThomas Raoux /// To
5360af26805SThomas Raoux /// ```
5370af26805SThomas Raoux /// vector.warp_execute_on_lane_0(%arg0 {
5380af26805SThomas Raoux ///   ...
5390af26805SThomas Raoux /// }
5400af26805SThomas Raoux /// %0 = arith.constant dense<2.0> : vector<1xf32>
5410af26805SThomas Raoux struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
5420af26805SThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpConstant5430af26805SThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
5440af26805SThomas Raoux                                 PatternRewriter &rewriter) const override {
5450af26805SThomas Raoux     OpOperand *yieldOperand = getWarpResult(
5460af26805SThomas Raoux         warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
5470af26805SThomas Raoux     if (!yieldOperand)
5480af26805SThomas Raoux       return failure();
5490af26805SThomas Raoux     auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
5500af26805SThomas Raoux     auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
5510af26805SThomas Raoux     if (!dense)
5520af26805SThomas Raoux       return failure();
5530af26805SThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
5540af26805SThomas Raoux     Attribute scalarAttr = dense.getSplatValue<Attribute>();
5550af26805SThomas Raoux     Attribute newAttr = DenseElementsAttr::get(
5560af26805SThomas Raoux         warpOp.getResult(operandIndex).getType(), scalarAttr);
5570af26805SThomas Raoux     Location loc = warpOp.getLoc();
5580af26805SThomas Raoux     rewriter.setInsertionPointAfter(warpOp);
5590af26805SThomas Raoux     Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
5600af26805SThomas Raoux     warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
5610af26805SThomas Raoux     return success();
5620af26805SThomas Raoux   }
5630af26805SThomas Raoux };
5640af26805SThomas Raoux 
56576cf33daSThomas Raoux /// Sink out transfer_read op feeding into a warp op yield.
56676cf33daSThomas Raoux /// ```
56776cf33daSThomas Raoux /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
56876cf33daSThomas Raoux ///   ...
56976cf33daSThomas Raoux //    %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
57076cf33daSThomas Raoux //    vector<32xf32>
57176cf33daSThomas Raoux ///   vector.yield %2 : vector<32xf32>
57276cf33daSThomas Raoux /// }
57376cf33daSThomas Raoux /// ```
57476cf33daSThomas Raoux /// To
57576cf33daSThomas Raoux /// ```
57676cf33daSThomas Raoux /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
57776cf33daSThomas Raoux /// vector<1xf32>, vector<1xf32>) {
57876cf33daSThomas Raoux ///   ...
57976cf33daSThomas Raoux ///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
58076cf33daSThomas Raoux ///   vector<32xf32> vector.yield %2 : vector<32xf32>
58176cf33daSThomas Raoux /// }
58276cf33daSThomas Raoux /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
58376cf33daSThomas Raoux struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
58476cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpTransferRead58576cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
58676cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
58776cf33daSThomas Raoux     OpOperand *operand = getWarpResult(
58876cf33daSThomas Raoux         warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
58976cf33daSThomas Raoux     if (!operand)
59076cf33daSThomas Raoux       return failure();
59176cf33daSThomas Raoux     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
59276cf33daSThomas Raoux     unsigned operandIndex = operand->getOperandNumber();
59376cf33daSThomas Raoux     Value distributedVal = warpOp.getResult(operandIndex);
59476cf33daSThomas Raoux 
59576cf33daSThomas Raoux     SmallVector<Value, 4> indices(read.getIndices().begin(),
59676cf33daSThomas Raoux                                   read.getIndices().end());
59776cf33daSThomas Raoux     AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
59876cf33daSThomas Raoux     AffineMap indexMap = map.compose(read.getPermutationMap());
59976cf33daSThomas Raoux     OpBuilder::InsertionGuard g(rewriter);
60076cf33daSThomas Raoux     rewriter.setInsertionPointAfter(warpOp);
60176cf33daSThomas Raoux     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
60276cf33daSThomas Raoux       AffineExpr d0, d1;
60376cf33daSThomas Raoux       bindDims(read.getContext(), d0, d1);
60476cf33daSThomas Raoux       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
60576cf33daSThomas Raoux       if (!indexExpr)
60676cf33daSThomas Raoux         continue;
60776cf33daSThomas Raoux       unsigned indexPos = indexExpr.getPosition();
60876cf33daSThomas Raoux       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
60976cf33daSThomas Raoux       int64_t scale =
61076cf33daSThomas Raoux           distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
61176cf33daSThomas Raoux       indices[indexPos] =
61276cf33daSThomas Raoux           makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
61376cf33daSThomas Raoux                                   {indices[indexPos], warpOp.getLaneid()});
61476cf33daSThomas Raoux     }
61576cf33daSThomas Raoux     Value newRead = rewriter.create<vector::TransferReadOp>(
61676cf33daSThomas Raoux         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
61776cf33daSThomas Raoux         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
61876cf33daSThomas Raoux         read.getInBoundsAttr());
61976cf33daSThomas Raoux     distributedVal.replaceAllUsesWith(newRead);
62076cf33daSThomas Raoux     return success();
62176cf33daSThomas Raoux   }
62276cf33daSThomas Raoux };
62376cf33daSThomas Raoux 
62476cf33daSThomas Raoux /// Remove any result that has no use along with the matching yieldOp operand.
62576cf33daSThomas Raoux // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
62676cf33daSThomas Raoux struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
62776cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpDeadResult62876cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
62976cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
63076cf33daSThomas Raoux     SmallVector<Type> resultTypes;
63176cf33daSThomas Raoux     SmallVector<Value> yieldValues;
63276cf33daSThomas Raoux     auto yield = cast<vector::YieldOp>(
63376cf33daSThomas Raoux         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
63476cf33daSThomas Raoux     for (OpResult result : warpOp.getResults()) {
63576cf33daSThomas Raoux       if (result.use_empty())
63676cf33daSThomas Raoux         continue;
63776cf33daSThomas Raoux       resultTypes.push_back(result.getType());
63876cf33daSThomas Raoux       yieldValues.push_back(yield.getOperand(result.getResultNumber()));
63976cf33daSThomas Raoux     }
64076cf33daSThomas Raoux     if (yield.getNumOperands() == yieldValues.size())
64176cf33daSThomas Raoux       return failure();
64276cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
64376cf33daSThomas Raoux         rewriter, warpOp, yieldValues, resultTypes);
64476cf33daSThomas Raoux     unsigned resultIndex = 0;
64576cf33daSThomas Raoux     for (OpResult result : warpOp.getResults()) {
64676cf33daSThomas Raoux       if (result.use_empty())
64776cf33daSThomas Raoux         continue;
64876cf33daSThomas Raoux       result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
64976cf33daSThomas Raoux     }
65076cf33daSThomas Raoux     rewriter.eraseOp(warpOp);
65176cf33daSThomas Raoux     return success();
65276cf33daSThomas Raoux   }
65376cf33daSThomas Raoux };
65476cf33daSThomas Raoux 
65576cf33daSThomas Raoux // If an operand is directly yielded out of the region we can forward it
65676cf33daSThomas Raoux // directly and it doesn't need to go through the region.
65776cf33daSThomas Raoux struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
65876cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpForwardOperand65976cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
66076cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
66176cf33daSThomas Raoux     SmallVector<Type> resultTypes;
66276cf33daSThomas Raoux     SmallVector<Value> yieldValues;
66376cf33daSThomas Raoux     auto yield = cast<vector::YieldOp>(
66476cf33daSThomas Raoux         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
66576cf33daSThomas Raoux     Value valForwarded;
66676cf33daSThomas Raoux     unsigned resultIndex;
66776cf33daSThomas Raoux     for (OpOperand &operand : yield->getOpOperands()) {
66876cf33daSThomas Raoux       Value result = warpOp.getResult(operand.getOperandNumber());
66976cf33daSThomas Raoux       if (result.use_empty())
67076cf33daSThomas Raoux         continue;
67176cf33daSThomas Raoux 
67276cf33daSThomas Raoux       // Assume all the values coming from above are uniform.
67376cf33daSThomas Raoux       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
67476cf33daSThomas Raoux         if (result.getType() != operand.get().getType())
67576cf33daSThomas Raoux           continue;
67676cf33daSThomas Raoux         valForwarded = operand.get();
67776cf33daSThomas Raoux         resultIndex = operand.getOperandNumber();
67876cf33daSThomas Raoux         break;
67976cf33daSThomas Raoux       }
68076cf33daSThomas Raoux       auto arg = operand.get().dyn_cast<BlockArgument>();
68176cf33daSThomas Raoux       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
68276cf33daSThomas Raoux         continue;
68376cf33daSThomas Raoux       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
68476cf33daSThomas Raoux       if (result.getType() != warpOperand.getType())
68576cf33daSThomas Raoux         continue;
68676cf33daSThomas Raoux       valForwarded = warpOperand;
68776cf33daSThomas Raoux       resultIndex = operand.getOperandNumber();
68876cf33daSThomas Raoux       break;
68976cf33daSThomas Raoux     }
69076cf33daSThomas Raoux     if (!valForwarded)
69176cf33daSThomas Raoux       return failure();
69276cf33daSThomas Raoux     warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
69376cf33daSThomas Raoux     return success();
69476cf33daSThomas Raoux   }
69576cf33daSThomas Raoux };
69676cf33daSThomas Raoux 
69776cf33daSThomas Raoux struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
69876cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpBroadcast69976cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
70076cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
70176cf33daSThomas Raoux     OpOperand *operand = getWarpResult(
70276cf33daSThomas Raoux         warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
70376cf33daSThomas Raoux     if (!operand)
70476cf33daSThomas Raoux       return failure();
70576cf33daSThomas Raoux     unsigned int operandNumber = operand->getOperandNumber();
70676cf33daSThomas Raoux     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
70776cf33daSThomas Raoux     Location loc = broadcastOp.getLoc();
70876cf33daSThomas Raoux     auto destVecType =
70976cf33daSThomas Raoux         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
710d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
71176cf33daSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
71276cf33daSThomas Raoux         rewriter, warpOp, {broadcastOp.getSource()},
713d7d6443dSThomas Raoux         {broadcastOp.getSource().getType()}, newRetIndices);
71476cf33daSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
71576cf33daSThomas Raoux     Value broadcasted = rewriter.create<vector::BroadcastOp>(
716d7d6443dSThomas Raoux         loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
71776cf33daSThomas Raoux     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
71876cf33daSThomas Raoux     return success();
71976cf33daSThomas Raoux   }
72076cf33daSThomas Raoux };
72176cf33daSThomas Raoux 
722*f48ce52cSThomas Raoux /// Pattern to move out vector.extract of single element vector. Those don't
723*f48ce52cSThomas Raoux /// need to be distributed and can just be propagated outside of the region.
724*f48ce52cSThomas Raoux struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
725*f48ce52cSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpExtract726*f48ce52cSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
727*f48ce52cSThomas Raoux                                 PatternRewriter &rewriter) const override {
728*f48ce52cSThomas Raoux     OpOperand *operand = getWarpResult(
729*f48ce52cSThomas Raoux         warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
730*f48ce52cSThomas Raoux     if (!operand)
731*f48ce52cSThomas Raoux       return failure();
732*f48ce52cSThomas Raoux     unsigned int operandNumber = operand->getOperandNumber();
733*f48ce52cSThomas Raoux     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
734*f48ce52cSThomas Raoux     if (extractOp.getVectorType().getNumElements() != 1)
735*f48ce52cSThomas Raoux       return failure();
736*f48ce52cSThomas Raoux     Location loc = extractOp.getLoc();
737*f48ce52cSThomas Raoux     SmallVector<size_t> newRetIndices;
738*f48ce52cSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
739*f48ce52cSThomas Raoux         rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
740*f48ce52cSThomas Raoux         newRetIndices);
741*f48ce52cSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
742*f48ce52cSThomas Raoux     Value newExtract = rewriter.create<vector::ExtractOp>(
743*f48ce52cSThomas Raoux         loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
744*f48ce52cSThomas Raoux     newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
745*f48ce52cSThomas Raoux     return success();
746*f48ce52cSThomas Raoux   }
747*f48ce52cSThomas Raoux };
748*f48ce52cSThomas Raoux 
74976cf33daSThomas Raoux /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
75076cf33daSThomas Raoux /// the scf.ForOp is the last operation in the region so that it doesn't change
75176cf33daSThomas Raoux /// the order of execution. This creates a new scf.for region after the
75276cf33daSThomas Raoux /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
75376cf33daSThomas Raoux /// WarpExecuteOnLane0Op region. Example:
75476cf33daSThomas Raoux /// ```
75576cf33daSThomas Raoux /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
75676cf33daSThomas Raoux ///   ...
75776cf33daSThomas Raoux ///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
75876cf33daSThomas Raoux ///   -> (vector<128xf32>) {
75976cf33daSThomas Raoux ///     ...
76076cf33daSThomas Raoux ///     scf.yield %r : vector<128xf32>
76176cf33daSThomas Raoux ///   }
76276cf33daSThomas Raoux ///   vector.yield %v1 : vector<128xf32>
76376cf33daSThomas Raoux /// }
76476cf33daSThomas Raoux /// ```
76576cf33daSThomas Raoux /// To:
76676cf33daSThomas Raoux /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
76776cf33daSThomas Raoux ///   ...
76876cf33daSThomas Raoux ///   vector.yield %v : vector<128xf32>
76976cf33daSThomas Raoux /// }
77076cf33daSThomas Raoux /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
77176cf33daSThomas Raoux ///   -> (vector<4xf32>) {
77276cf33daSThomas Raoux ///     %iw = vector.warp_execute_on_lane_0(%laneid)
77376cf33daSThomas Raoux ///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
77476cf33daSThomas Raoux ///     ^bb0(%arg: vector<128xf32>):
77576cf33daSThomas Raoux ///       ...
77676cf33daSThomas Raoux ///       vector.yield %ir : vector<128xf32>
77776cf33daSThomas Raoux ///     }
77876cf33daSThomas Raoux ///     scf.yield %iw : vector<4xf32>
77976cf33daSThomas Raoux ///  }
78076cf33daSThomas Raoux /// ```
78176cf33daSThomas Raoux struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
78276cf33daSThomas Raoux   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
matchAndRewrite__anonbfa0aa500211::WarpOpScfForOp78376cf33daSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
78476cf33daSThomas Raoux                                 PatternRewriter &rewriter) const override {
78576cf33daSThomas Raoux     auto yield = cast<vector::YieldOp>(
78676cf33daSThomas Raoux         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
78776cf33daSThomas Raoux     // Only pick up forOp if it is the last op in the region.
78876cf33daSThomas Raoux     Operation *lastNode = yield->getPrevNode();
78976cf33daSThomas Raoux     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
79076cf33daSThomas Raoux     if (!forOp)
79176cf33daSThomas Raoux       return failure();
79276cf33daSThomas Raoux     SmallVector<Value> newOperands;
79376cf33daSThomas Raoux     SmallVector<unsigned> resultIdx;
79476cf33daSThomas Raoux     // Collect all the outputs coming from the forOp.
79576cf33daSThomas Raoux     for (OpOperand &yieldOperand : yield->getOpOperands()) {
79676cf33daSThomas Raoux       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
79776cf33daSThomas Raoux         continue;
79876cf33daSThomas Raoux       auto forResult = yieldOperand.get().cast<OpResult>();
79976cf33daSThomas Raoux       newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
80076cf33daSThomas Raoux       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
80176cf33daSThomas Raoux       resultIdx.push_back(yieldOperand.getOperandNumber());
80276cf33daSThomas Raoux     }
80376cf33daSThomas Raoux     OpBuilder::InsertionGuard g(rewriter);
80476cf33daSThomas Raoux     rewriter.setInsertionPointAfter(warpOp);
80576cf33daSThomas Raoux     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
80676cf33daSThomas Raoux     // inside.
80776cf33daSThomas Raoux     auto newForOp = rewriter.create<scf::ForOp>(
80876cf33daSThomas Raoux         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
80976cf33daSThomas Raoux         forOp.getStep(), newOperands);
81076cf33daSThomas Raoux     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
81176cf33daSThomas Raoux     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
81276cf33daSThomas Raoux         warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
81376cf33daSThomas Raoux         warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
81476cf33daSThomas Raoux         forOp.getResultTypes());
81576cf33daSThomas Raoux 
81676cf33daSThomas Raoux     SmallVector<Value> argMapping;
81776cf33daSThomas Raoux     argMapping.push_back(newForOp.getInductionVar());
81876cf33daSThomas Raoux     for (Value args : innerWarp.getBody()->getArguments()) {
81976cf33daSThomas Raoux       argMapping.push_back(args);
82076cf33daSThomas Raoux     }
82176cf33daSThomas Raoux     SmallVector<Value> yieldOperands;
82276cf33daSThomas Raoux     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
82376cf33daSThomas Raoux       yieldOperands.push_back(operand);
82476cf33daSThomas Raoux     rewriter.eraseOp(forOp.getBody()->getTerminator());
82576cf33daSThomas Raoux     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
82676cf33daSThomas Raoux     rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
82776cf33daSThomas Raoux     rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
82876cf33daSThomas Raoux     rewriter.setInsertionPointAfter(innerWarp);
829d343cdd5SThomas Raoux     if (!innerWarp.getResults().empty())
83076cf33daSThomas Raoux       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
83176cf33daSThomas Raoux     rewriter.eraseOp(forOp);
83276cf33daSThomas Raoux     // Replace the warpOp result coming from the original ForOp.
83376cf33daSThomas Raoux     for (const auto &res : llvm::enumerate(resultIdx)) {
83476cf33daSThomas Raoux       warpOp.getResult(res.value())
83576cf33daSThomas Raoux           .replaceAllUsesWith(newForOp.getResult(res.index()));
83676cf33daSThomas Raoux       newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
83776cf33daSThomas Raoux     }
83876cf33daSThomas Raoux     return success();
83976cf33daSThomas Raoux   }
84076cf33daSThomas Raoux };
84176cf33daSThomas Raoux 
842087aba4fSThomas Raoux /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
8436834803cSThomas Raoux /// The vector is reduced in parallel. Currently limited to vector size matching
8446834803cSThomas Raoux /// the warpOp size. E.g.:
845087aba4fSThomas Raoux /// ```
8466834803cSThomas Raoux /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
847087aba4fSThomas Raoux ///   %0 = "some_def"() : () -> (vector<32xf32>)
848087aba4fSThomas Raoux ///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
849087aba4fSThomas Raoux ///   vector_ext.yield %1 : f32
850087aba4fSThomas Raoux /// }
851087aba4fSThomas Raoux /// ```
852087aba4fSThomas Raoux /// is lowered to:
853087aba4fSThomas Raoux /// ```
8546834803cSThomas Raoux /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
855087aba4fSThomas Raoux ///   %1 = "some_def"() : () -> (vector<32xf32>)
856087aba4fSThomas Raoux ///   vector_ext.yield %1 : vector<32xf32>
857087aba4fSThomas Raoux /// }
858087aba4fSThomas Raoux /// %a = vector.extract %0[0] : vector<1xf32>
8596834803cSThomas Raoux /// %r = ("warp.reduction %a")
860087aba4fSThomas Raoux /// ```
8616834803cSThomas Raoux struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
WarpOpReduction__anonbfa0aa500211::WarpOpReduction8626834803cSThomas Raoux   WarpOpReduction(MLIRContext *context,
8636834803cSThomas Raoux                   DistributedReductionFn distributedReductionFn,
8646834803cSThomas Raoux                   PatternBenefit benefit = 1)
8656834803cSThomas Raoux       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
8666834803cSThomas Raoux         distributedReductionFn(distributedReductionFn) {}
867087aba4fSThomas Raoux 
matchAndRewrite__anonbfa0aa500211::WarpOpReduction868087aba4fSThomas Raoux   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
869087aba4fSThomas Raoux                                 PatternRewriter &rewriter) const override {
870087aba4fSThomas Raoux     OpOperand *yieldOperand = getWarpResult(
871087aba4fSThomas Raoux         warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
872087aba4fSThomas Raoux     if (!yieldOperand)
873087aba4fSThomas Raoux       return failure();
874087aba4fSThomas Raoux 
875087aba4fSThomas Raoux     auto reductionOp =
876087aba4fSThomas Raoux         cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
877087aba4fSThomas Raoux     auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
878087aba4fSThomas Raoux     // Only rank 1 vectors supported.
879087aba4fSThomas Raoux     if (vectorType.getRank() != 1)
880087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
881087aba4fSThomas Raoux           warpOp, "Only rank 1 reductions can be distributed.");
882087aba4fSThomas Raoux     // Only warp_size-sized vectors supported.
8830660f3c5SThomas Raoux     if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
884087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
885087aba4fSThomas Raoux           warpOp, "Reduction vector dimension must match was size.");
886087aba4fSThomas Raoux     // Only f32 and i32 element types are supported.
887087aba4fSThomas Raoux     if (!reductionOp.getType().isF32() &&
888087aba4fSThomas Raoux         !reductionOp.getType().isSignlessInteger(32))
889087aba4fSThomas Raoux       return rewriter.notifyMatchFailure(
890087aba4fSThomas Raoux           warpOp,
891087aba4fSThomas Raoux           "Reduction distribution currently only supports 32bits types.");
892087aba4fSThomas Raoux 
8930660f3c5SThomas Raoux     int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
894087aba4fSThomas Raoux     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
895087aba4fSThomas Raoux     unsigned operandIndex = yieldOperand->getOperandNumber();
896087aba4fSThomas Raoux     SmallVector<Value> yieldValues = {reductionOp.getVector()};
8970660f3c5SThomas Raoux     SmallVector<Type> retTypes = {
8980660f3c5SThomas Raoux         VectorType::get({numElements}, reductionOp.getType())};
899ffa7384fSThomas Raoux     if (reductionOp.getAcc()) {
900ffa7384fSThomas Raoux       yieldValues.push_back(reductionOp.getAcc());
901ffa7384fSThomas Raoux       retTypes.push_back(reductionOp.getAcc().getType());
902ffa7384fSThomas Raoux     }
903d7d6443dSThomas Raoux     SmallVector<size_t> newRetIndices;
904087aba4fSThomas Raoux     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
905d7d6443dSThomas Raoux         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
906087aba4fSThomas Raoux     rewriter.setInsertionPointAfter(newWarpOp);
907087aba4fSThomas Raoux 
908d7d6443dSThomas Raoux     Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
9090660f3c5SThomas Raoux     // First reduce on a single thread.
9100660f3c5SThomas Raoux     Value perLaneReduction = rewriter.create<vector::ReductionOp>(
9110660f3c5SThomas Raoux         reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
9120660f3c5SThomas Raoux     // Then distribute across threads.
9130660f3c5SThomas Raoux     Value fullReduce =
9140660f3c5SThomas Raoux         distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
9156834803cSThomas Raoux                                reductionOp.getKind(), newWarpOp.getWarpSize());
916ffa7384fSThomas Raoux     if (reductionOp.getAcc()) {
917ffa7384fSThomas Raoux       fullReduce = vector::makeArithReduction(
918ffa7384fSThomas Raoux           rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
919ffa7384fSThomas Raoux           newWarpOp.getResult(newRetIndices[1]));
920ffa7384fSThomas Raoux     }
9210660f3c5SThomas Raoux     newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
922087aba4fSThomas Raoux     return success();
923087aba4fSThomas Raoux   }
9246834803cSThomas Raoux 
9256834803cSThomas Raoux private:
9266834803cSThomas Raoux   DistributedReductionFn distributedReductionFn;
927087aba4fSThomas Raoux };
928087aba4fSThomas Raoux 
929d02f10d9SThomas Raoux } // namespace
930d02f10d9SThomas Raoux 
populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet & patterns,const WarpExecuteOnLane0LoweringOptions & options)931d02f10d9SThomas Raoux void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
932d02f10d9SThomas Raoux     RewritePatternSet &patterns,
933d02f10d9SThomas Raoux     const WarpExecuteOnLane0LoweringOptions &options) {
934d02f10d9SThomas Raoux   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
935d02f10d9SThomas Raoux }
936ed0288f7SThomas Raoux 
populateDistributeTransferWriteOpPatterns(RewritePatternSet & patterns,const DistributionMapFn & distributionMapFn)937ed0288f7SThomas Raoux void mlir::vector::populateDistributeTransferWriteOpPatterns(
93808d651d7SMehdi Amini     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
939ed0288f7SThomas Raoux   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
940ed0288f7SThomas Raoux }
941ed0288f7SThomas Raoux 
populatePropagateWarpVectorDistributionPatterns(RewritePatternSet & patterns)94276cf33daSThomas Raoux void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
94376cf33daSThomas Raoux     RewritePatternSet &patterns) {
94476cf33daSThomas Raoux   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
945*f48ce52cSThomas Raoux                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
946*f48ce52cSThomas Raoux                WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
94776cf33daSThomas Raoux }
94876cf33daSThomas Raoux 
populateDistributeReduction(RewritePatternSet & patterns,DistributedReductionFn distributedReductionFn)9496834803cSThomas Raoux void mlir::vector::populateDistributeReduction(
9506834803cSThomas Raoux     RewritePatternSet &patterns,
9516834803cSThomas Raoux     DistributedReductionFn distributedReductionFn) {
9526834803cSThomas Raoux   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
953087aba4fSThomas Raoux }
954087aba4fSThomas Raoux 
moveScalarUniformCode(WarpExecuteOnLane0Op warpOp)955ed0288f7SThomas Raoux void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
956ed0288f7SThomas Raoux   Block *body = warpOp.getBody();
957ed0288f7SThomas Raoux 
958ed0288f7SThomas Raoux   // Keep track of the ops we want to hoist.
959ed0288f7SThomas Raoux   llvm::SmallSetVector<Operation *, 8> opsToMove;
960ed0288f7SThomas Raoux 
961ed0288f7SThomas Raoux   // Helper to check if a value is or will be defined outside of the region.
962ed0288f7SThomas Raoux   auto isDefinedOutsideOfBody = [&](Value value) {
963ed0288f7SThomas Raoux     auto *definingOp = value.getDefiningOp();
964ed0288f7SThomas Raoux     return (definingOp && opsToMove.count(definingOp)) ||
965ed0288f7SThomas Raoux            warpOp.isDefinedOutsideOfRegion(value);
966ed0288f7SThomas Raoux   };
967ed0288f7SThomas Raoux 
968ed0288f7SThomas Raoux   // Do not use walk here, as we do not want to go into nested regions and hoist
969ed0288f7SThomas Raoux   // operations from there.
970ed0288f7SThomas Raoux   for (auto &op : body->without_terminator()) {
971ed0288f7SThomas Raoux     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
972ed0288f7SThomas Raoux       return result.getType().isa<VectorType>();
973ed0288f7SThomas Raoux     });
974ed0288f7SThomas Raoux     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
975ed0288f7SThomas Raoux       opsToMove.insert(&op);
976ed0288f7SThomas Raoux   }
977ed0288f7SThomas Raoux 
978ed0288f7SThomas Raoux   // Move all the ops marked as uniform outside of the region.
979ed0288f7SThomas Raoux   for (Operation *op : opsToMove)
980ed0288f7SThomas Raoux     op->moveBefore(warpOp);
981ed0288f7SThomas Raoux }
982