//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

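/// Lower a WarpExecuteOnLane0Op to an scf.if guarded on lane 0. Values that
/// cross the region boundary travel through scratch pad buffers obtained from
/// `options.warpAllocationFn`, with `options.warpSyncronizationFn` called
/// between the stores and the loads. A minimal sketch of the resulting IR
/// (buffer allocation and synchronization elided; names are illustrative):
/// ```
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   ...  // former body of the warp op; yielded values stored to buffers
/// }
/// %r = vector.load %buffer[%offset]  // each lane loads its slice
/// ```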
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Replace the warpOp results with the values loaded from the buffers.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op with extra outputs.
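/// Sketch (illustrative values and types): appending a yielded value
/// `%v : vector<32xf32>` with return type `vector<1xf32>` rewrites
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid) -> (f32) { ... }
/// ```
/// into
/// ```
/// %1:2 = vector.warp_execute_on_lane_0(%laneid) -> (f32, vector<1xf32>) { ... }
/// ```
/// and replaces all uses of `%0` with `%1#0`.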
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return the yield operand of `warpOp` whose defining op satisfies the filter
/// lambda condition and whose matching warp op result is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValue = yieldOperand.get();
    Operation *definedOp = yieldValue.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value differ from the dimensions of
  // the result to determine the distributed dimensions. Then associate each
  // distributed dimension with an ID in order.
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector
  /// dimensions that are multiples of the warp size are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and delete
    // the old op.
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1xT> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only vector<1xT> is supported at the moment.
    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with a mask are not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
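/// Sketch: if `%r#1` below is unused,
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>, f32) {
///   ...
///   vector.yield %v, %f : vector<32xf32>, f32
/// }
/// ```
/// the op is rewritten to yield only `%v` and return only `vector<1xf32>`.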
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If a value yielded out of the region is defined outside of it, or comes in
// through a block argument, it can be forwarded to the matching result
// directly and does not need to go through the region.
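// Sketch (assuming `%above` is defined outside the region):
// ```
// %r = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
//   ...
//   vector.yield %above : f32
// }
// ```
// All uses of `%r` are replaced with `%above` directly.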
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

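/// Sink out a broadcast op feeding into a warp op yield, and recreate the
/// broadcast on the distributed result type after the warp op. Sketch
/// (illustrative types):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %b : vector<32xf32>
/// }
/// ```
/// becomes a warp op that yields `%s : f32`, followed by
/// `vector.broadcast %s : f32 to vector<1xf32>` outside the region.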
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink an scf.for region out of WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
    // inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operands 0-2 of scf.for are the lower bound, upper bound and step;
      // the iter_args start at operand 3.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to rank-1 vectors
/// whose size equals the warp size (e.g. vector<32xf32> for a 32-lane warp,
/// where every lane combines two scalars per step, 5 times in a row).
/// E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r0, %s0 = gpu.shuffle xor %a, %c1, %c32 : f32
/// %a0 = arith.addf %a, %r0 : f32
/// %r1, %s1 = gpu.shuffle xor %a0, %c2, %c32 : f32
/// %a1 = arith.addf %a0, %r1 : f32
/// ...
/// %r4, %s4 = gpu.shuffle xor %a3, %c16, %c32 : f32
/// %r = arith.addf %a3, %r4 : f32
/// ```
struct ReductionToGPUWarpShuffle
    : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Return the vector that will be reduced from the WarpExecuteOnLane0Op;
    // after distribution each lane holds a vector<1xT> slice of it.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);

    // Parallel reduction using butterfly shuffles.
    for (uint64_t i = 1; i < newWarpOp.getWarpSize(); i <<= 1) {
      Value shuffled =
          rewriter
              .create<gpu::ShuffleOp>(reductionOp.getLoc(), laneVal, i,
                                      /*width=*/newWarpOp.getWarpSize(),
                                      /*mode=*/gpu::ShuffleMode::XOR)
              .result();
      laneVal = makeArithReduction(rewriter, reductionOp.getLoc(),
                                   reductionOp.getKind(), laneVal, shuffled);
    }

    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateReductionToGPUWarpShufflePatterns(
    RewritePatternSet &patterns) {
  patterns.add<ReductionToGPUWarpShuffle>(patterns.getContext());
}
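
// A minimal usage sketch for the populate entry points above (illustrative,
// not from this file: `ctx`, `funcOp`, and `distributionMapFn` are assumed to
// be provided by the caller, and the greedy driver comes from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateDistributeTransferWriteOpPatterns(patterns,
//                                                     distributionMapFn);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateReductionToGPUWarpShufflePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));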
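/// Hoist uniform scalar code out of `warpOp`: side-effect-free ops without
/// regions, with no vector results, and whose operands are all defined outside
/// the region (or are themselves hoisted) compute the same value on every
/// lane, so they can safely be moved before the warp op. Sketch (illustrative):
/// an `arith.addi` of two indices defined above the region is moved out, while
/// any op producing a vector stays inside.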
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and hoist
  // operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}