1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/Transforms/SideEffectUtils.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 static LogicalResult
21 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
22                       const WarpExecuteOnLane0LoweringOptions &options) {
23   assert(warpOp.getBodyRegion().hasOneBlock() &&
24          "expected WarpOp with single block");
25   Block *warpOpBody = &warpOp.getBodyRegion().front();
26   Location loc = warpOp.getLoc();
27 
28   // Passed all checks. Start rewriting.
29   OpBuilder::InsertionGuard g(rewriter);
30   rewriter.setInsertionPoint(warpOp);
31 
32   // Create scf.if op.
33   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
34   Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
35                                                  warpOp.getLaneid(), c0);
36   auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
37                                          /*withElseRegion=*/false);
38   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
39 
40   // Store vectors that are defined outside of warpOp into the scratch pad
41   // buffer.
42   SmallVector<Value> bbArgReplacements;
43   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
44     Value val = it.value();
45     Value bbArg = warpOpBody->getArgument(it.index());
46 
47     rewriter.setInsertionPoint(ifOp);
48     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
49                                             bbArg.getType());
50 
51     // Store arg vector into buffer.
52     rewriter.setInsertionPoint(ifOp);
53     auto vectorType = val.getType().cast<VectorType>();
54     int64_t storeSize = vectorType.getShape()[0];
55     Value storeOffset = rewriter.create<arith::MulIOp>(
56         loc, warpOp.getLaneid(),
57         rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
58     rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
59 
60     // Load bbArg vector from buffer.
61     rewriter.setInsertionPointToStart(ifOp.thenBlock());
62     auto bbArgType = bbArg.getType().cast<VectorType>();
63     Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
64     bbArgReplacements.push_back(loadOp);
65   }
66 
67   // Insert sync after all the stores and before all the loads.
68   if (!warpOp.getArgs().empty()) {
69     rewriter.setInsertionPoint(ifOp);
70     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
71   }
72 
73   // Move body of warpOp to ifOp.
74   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
75 
76   // Rewrite terminator and compute replacements of WarpOp results.
77   SmallVector<Value> replacements;
78   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
79   Location yieldLoc = yieldOp.getLoc();
80   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
81     Value val = it.value();
82     Type resultType = warpOp->getResultTypes()[it.index()];
83     rewriter.setInsertionPoint(ifOp);
84     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
85                                             val.getType());
86 
87     // Store yielded value into buffer.
88     rewriter.setInsertionPoint(yieldOp);
89     if (val.getType().isa<VectorType>())
90       rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
91     else
92       rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
93 
94     // Load value from buffer (after warpOp).
95     rewriter.setInsertionPointAfter(ifOp);
96     if (resultType == val.getType()) {
97       // Result type and yielded value type are the same. This is a broadcast.
98       // E.g.:
99       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
100       //   vector.yield %cst : f32
101       // }
102       // Both types are f32. The constant %cst is broadcasted to all lanes.
103       // This is described in more detail in the documentation of the op.
104       Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
105       replacements.push_back(loadOp);
106     } else {
107       auto loadedVectorType = resultType.cast<VectorType>();
108       int64_t loadSize = loadedVectorType.getShape()[0];
109 
110       // loadOffset = laneid * loadSize
111       Value loadOffset = rewriter.create<arith::MulIOp>(
112           loc, warpOp.getLaneid(),
113           rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
114       Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
115                                                      buffer, loadOffset);
116       replacements.push_back(loadOp);
117     }
118   }
119 
120   // Insert sync after all the stores and before all the loads.
121   if (!yieldOp.operands().empty()) {
122     rewriter.setInsertionPointAfter(ifOp);
123     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
124   }
125 
126   // Delete terminator and add empty scf.yield.
127   rewriter.eraseOp(yieldOp);
128   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
129   rewriter.create<scf::YieldOp>(yieldLoc);
130 
131   // Compute replacements for WarpOp results.
132   rewriter.replaceOp(warpOp, replacements);
133 
134   return success();
135 }
136 
137 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
138 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
139     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
140     ValueRange newYieldedValues, TypeRange newReturnTypes) {
141   // Create a new op before the existing one, with the extra operands.
142   OpBuilder::InsertionGuard g(rewriter);
143   rewriter.setInsertionPoint(warpOp);
144   auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
145       warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
146       warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
147 
148   Region &opBody = warpOp.getBodyRegion();
149   Region &newOpBody = newWarpOp.getBodyRegion();
150   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
151   auto yield =
152       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
153 
154   rewriter.updateRootInPlace(
155       yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
156   return newWarpOp;
157 }
158 
159 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
160 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
161     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
162     ValueRange newYieldedValues, TypeRange newReturnTypes) {
163   SmallVector<Type> types(warpOp.getResultTypes().begin(),
164                           warpOp.getResultTypes().end());
165   types.append(newReturnTypes.begin(), newReturnTypes.end());
166   auto yield = cast<vector::YieldOp>(
167       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
168   SmallVector<Value> yieldValues(yield.getOperands().begin(),
169                                  yield.getOperands().end());
170   yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
171   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
172       rewriter, warpOp, yieldValues, types);
173   rewriter.replaceOp(warpOp,
174                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
175   return newWarpOp;
176 }
177 
178 /// Helper to know if an op can be hoisted out of the region.
179 static bool canBeHoisted(Operation *op,
180                          function_ref<bool(Value)> definedOutside) {
181   return llvm::all_of(op->getOperands(), definedOutside) &&
182          isSideEffectFree(op) && op->getNumRegions() == 0;
183 }
184 
185 /// Return a value yielded by `warpOp` which statifies the filter lamdba
186 /// condition and is not dead.
187 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
188                                 std::function<bool(Operation *)> fn) {
189   auto yield = cast<vector::YieldOp>(
190       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
191   for (OpOperand &yieldOperand : yield->getOpOperands()) {
192     Value yieldValues = yieldOperand.get();
193     Operation *definedOp = yieldValues.getDefiningOp();
194     if (definedOp && fn(definedOp)) {
195       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
196         return &yieldOperand;
197     }
198   }
199   return {};
200 }
201 
202 // Clones `op` into a new operation that takes `operands` and returns
203 // `resultTypes`.
204 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
205                                               Location loc, Operation *op,
206                                               ArrayRef<Value> operands,
207                                               ArrayRef<Type> resultTypes) {
208   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
209                      op->getAttrs());
210   return rewriter.create(res);
211 }
212 
213 /// Currently the distribution map is implicit based on the vector shape. In the
214 /// future it will be part of the op.
215 /// Example:
216 /// ```
217 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
218 ///   ...
219 ///   vector.yield %3 : vector<32x16x64xf32>
220 /// }
221 /// ```
222 /// Would have an implicit map of:
223 /// `(d0, d1, d2) -> (d0, d2)`
224 static AffineMap calculateImplicitMap(Value yield, Value ret) {
225   auto srcType = yield.getType().cast<VectorType>();
226   auto dstType = ret.getType().cast<VectorType>();
227   SmallVector<AffineExpr> perm;
228   // Check which dimensions of the yield value are different than the dimensions
229   // of the result to know the distributed dimensions. Then associate each
230   // distributed dimension to an ID in order.
231   for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
232     if (srcType.getDimSize(i) != dstType.getDimSize(i))
233       perm.push_back(getAffineDimExpr(i, yield.getContext()));
234   }
235   auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
236   return map;
237 }
238 
239 namespace {
240 
241 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
242   WarpOpToScfForPattern(MLIRContext *context,
243                         const WarpExecuteOnLane0LoweringOptions &options,
244                         PatternBenefit benefit = 1)
245       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
246         options(options) {}
247 
248   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
249                                 PatternRewriter &rewriter) const override {
250     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
251   }
252 
253 private:
254   const WarpExecuteOnLane0LoweringOptions &options;
255 };
256 
257 /// Distribute transfer_write ops based on the affine map returned by
258 /// `distributionMapFn`.
259 /// Example:
260 /// ```
261 /// %0 = vector.warp_execute_on_lane_0(%id){
262 ///   ...
263 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
264 ///   vector.yield
265 /// }
266 /// ```
267 /// To
268 /// ```
269 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
270 ///   ...
271 ///   vector.yield %v : vector<32xf32>
272 /// }
273 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
274 struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
275   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
276                       PatternBenefit b = 1)
277       : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
278         distributionMapFn(fn) {}
279 
280   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
281   /// are multiples of the distribution ratio are supported at the moment.
282   LogicalResult tryDistributeOp(RewriterBase &rewriter,
283                                 vector::TransferWriteOp writeOp,
284                                 WarpExecuteOnLane0Op warpOp) const {
285     AffineMap map = distributionMapFn(writeOp);
286     SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
287                                      writeOp.getVectorType().getShape().end());
288     assert(map.getNumResults() == 1 &&
289            "multi-dim distribution not implemented yet");
290     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
291       unsigned position = map.getDimPosition(i);
292       if (targetShape[position] % warpOp.getWarpSize() != 0)
293         return failure();
294       targetShape[position] = targetShape[position] / warpOp.getWarpSize();
295     }
296     VectorType targetType =
297         VectorType::get(targetShape, writeOp.getVectorType().getElementType());
298 
299     SmallVector<Value> yieldValues = {writeOp.getVector()};
300     SmallVector<Type> retTypes = {targetType};
301     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
302         rewriter, warpOp, yieldValues, retTypes);
303     rewriter.setInsertionPointAfter(newWarpOp);
304 
305     // Move op outside of region: Insert clone at the insertion point and delete
306     // the old op.
307     auto newWriteOp =
308         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
309     rewriter.eraseOp(writeOp);
310 
311     rewriter.setInsertionPoint(newWriteOp);
312     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
313     Location loc = newWriteOp.getLoc();
314     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
315                                newWriteOp.getIndices().end());
316     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
317       AffineExpr d0, d1;
318       bindDims(newWarpOp.getContext(), d0, d1);
319       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
320       if (!indexExpr)
321         continue;
322       unsigned indexPos = indexExpr.getPosition();
323       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
324       auto scale =
325           getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
326       indices[indexPos] =
327           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
328                                   {indices[indexPos], newWarpOp.getLaneid()});
329     }
330     newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
331     newWriteOp.getIndicesMutable().assign(indices);
332 
333     return success();
334   }
335 
336   /// Extract TransferWriteOps of vector<1x> into a separate warp op.
337   LogicalResult tryExtractOp(RewriterBase &rewriter,
338                              vector::TransferWriteOp writeOp,
339                              WarpExecuteOnLane0Op warpOp) const {
340     Location loc = writeOp.getLoc();
341     VectorType vecType = writeOp.getVectorType();
342 
343     // Only vector<1x> is supported at the moment.
344     if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
345       return failure();
346 
347     // Do not process warp ops that contain only TransferWriteOps.
348     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
349           return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
350         }))
351       return failure();
352 
353     SmallVector<Value> yieldValues = {writeOp.getVector()};
354     SmallVector<Type> retTypes = {vecType};
355     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
356         rewriter, warpOp, yieldValues, retTypes);
357     rewriter.setInsertionPointAfter(newWarpOp);
358 
359     // Create a second warp op that contains only writeOp.
360     auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
361         loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
362     Block &body = secondWarpOp.getBodyRegion().front();
363     rewriter.setInsertionPointToStart(&body);
364     auto newWriteOp =
365         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
366     newWriteOp.getVectorMutable().assign(
367         newWarpOp.getResult(newWarpOp.getNumResults() - 1));
368     rewriter.eraseOp(writeOp);
369     rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
370     return success();
371   }
372 
373   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
374                                 PatternRewriter &rewriter) const override {
375     // Ops with mask not supported yet.
376     if (writeOp.getMask())
377       return failure();
378 
379     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
380     if (!warpOp)
381       return failure();
382 
383     // There must be no op with a side effect after writeOp.
384     Operation *nextOp = writeOp.getOperation();
385     while ((nextOp = nextOp->getNextNode()))
386       if (!isSideEffectFree(nextOp))
387         return failure();
388 
389     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
390           return writeOp.getVector() == value ||
391                  warpOp.isDefinedOutsideOfRegion(value);
392         }))
393       return failure();
394 
395     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
396       return success();
397 
398     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
399       return success();
400 
401     return failure();
402   }
403 
404 private:
405   DistributionMapFn distributionMapFn;
406 };
407 
408 /// Sink out elementwise op feeding into a warp op yield.
409 /// ```
410 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
411 ///   ...
412 ///   %3 = arith.addf %1, %2 : vector<32xf32>
413 ///   vector.yield %3 : vector<32xf32>
414 /// }
415 /// ```
416 /// To
417 /// ```
418 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
419 /// vector<1xf32>, vector<1xf32>) {
420 ///   ...
421 ///   %4 = arith.addf %2, %3 : vector<32xf32>
422 ///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
423 ///   vector<32xf32>
424 /// }
425 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
426 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
427   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
428   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
429                                 PatternRewriter &rewriter) const override {
430     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
431       return OpTrait::hasElementwiseMappableTraits(op);
432     });
433     if (!yieldOperand)
434       return failure();
435     Operation *elementWise = yieldOperand->get().getDefiningOp();
436     unsigned operandIndex = yieldOperand->getOperandNumber();
437     Value distributedVal = warpOp.getResult(operandIndex);
438     SmallVector<Value> yieldValues;
439     SmallVector<Type> retTypes;
440     Location loc = warpOp.getLoc();
441     for (OpOperand &operand : elementWise->getOpOperands()) {
442       Type targetType;
443       if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
444         // If the result type is a vector, the operands must also be vectors.
445         auto operandType = operand.get().getType().cast<VectorType>();
446         targetType =
447             VectorType::get(vecType.getShape(), operandType.getElementType());
448       } else {
449         auto operandType = operand.get().getType();
450         assert(!operandType.isa<VectorType>() &&
451                "unexpected yield of vector from op with scalar result type");
452         targetType = operandType;
453       }
454       retTypes.push_back(targetType);
455       yieldValues.push_back(operand.get());
456     }
457     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
458         rewriter, warpOp, yieldValues, retTypes);
459     rewriter.setInsertionPointAfter(newWarpOp);
460     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
461                                    elementWise->getOperands().end());
462     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
463       newOperands[i] = newWarpOp.getResult(i + warpOp.getNumResults());
464     }
465     OpBuilder::InsertionGuard g(rewriter);
466     rewriter.setInsertionPointAfter(newWarpOp);
467     Operation *newOp = cloneOpWithOperandsAndTypes(
468         rewriter, loc, elementWise, newOperands,
469         {newWarpOp.getResult(operandIndex).getType()});
470     newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
471     return success();
472   }
473 };
474 
475 /// Sink out transfer_read op feeding into a warp op yield.
476 /// ```
477 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
478 ///   ...
479 //    %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
480 //    vector<32xf32>
481 ///   vector.yield %2 : vector<32xf32>
482 /// }
483 /// ```
484 /// To
485 /// ```
486 /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
487 /// vector<1xf32>, vector<1xf32>) {
488 ///   ...
489 ///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
490 ///   vector<32xf32> vector.yield %2 : vector<32xf32>
491 /// }
492 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
493 struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
494   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
495   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
496                                 PatternRewriter &rewriter) const override {
497     OpOperand *operand = getWarpResult(
498         warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
499     if (!operand)
500       return failure();
501     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
502     unsigned operandIndex = operand->getOperandNumber();
503     Value distributedVal = warpOp.getResult(operandIndex);
504 
505     SmallVector<Value, 4> indices(read.getIndices().begin(),
506                                   read.getIndices().end());
507     AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
508     AffineMap indexMap = map.compose(read.getPermutationMap());
509     OpBuilder::InsertionGuard g(rewriter);
510     rewriter.setInsertionPointAfter(warpOp);
511     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
512       AffineExpr d0, d1;
513       bindDims(read.getContext(), d0, d1);
514       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
515       if (!indexExpr)
516         continue;
517       unsigned indexPos = indexExpr.getPosition();
518       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
519       int64_t scale =
520           distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
521       indices[indexPos] =
522           makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
523                                   {indices[indexPos], warpOp.getLaneid()});
524     }
525     Value newRead = rewriter.create<vector::TransferReadOp>(
526         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
527         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
528         read.getInBoundsAttr());
529     distributedVal.replaceAllUsesWith(newRead);
530     return success();
531   }
532 };
533 
534 /// Remove any result that has no use along with the matching yieldOp operand.
535 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
536 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
537   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
538   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
539                                 PatternRewriter &rewriter) const override {
540     SmallVector<Type> resultTypes;
541     SmallVector<Value> yieldValues;
542     auto yield = cast<vector::YieldOp>(
543         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
544     for (OpResult result : warpOp.getResults()) {
545       if (result.use_empty())
546         continue;
547       resultTypes.push_back(result.getType());
548       yieldValues.push_back(yield.getOperand(result.getResultNumber()));
549     }
550     if (yield.getNumOperands() == yieldValues.size())
551       return failure();
552     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
553         rewriter, warpOp, yieldValues, resultTypes);
554     unsigned resultIndex = 0;
555     for (OpResult result : warpOp.getResults()) {
556       if (result.use_empty())
557         continue;
558       result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
559     }
560     rewriter.eraseOp(warpOp);
561     return success();
562   }
563 };
564 
565 // If an operand is directly yielded out of the region we can forward it
566 // directly and it doesn't need to go through the region.
567 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
568   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
569   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
570                                 PatternRewriter &rewriter) const override {
571     SmallVector<Type> resultTypes;
572     SmallVector<Value> yieldValues;
573     auto yield = cast<vector::YieldOp>(
574         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
575     Value valForwarded;
576     unsigned resultIndex;
577     for (OpOperand &operand : yield->getOpOperands()) {
578       Value result = warpOp.getResult(operand.getOperandNumber());
579       if (result.use_empty())
580         continue;
581 
582       // Assume all the values coming from above are uniform.
583       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
584         if (result.getType() != operand.get().getType())
585           continue;
586         valForwarded = operand.get();
587         resultIndex = operand.getOperandNumber();
588         break;
589       }
590       auto arg = operand.get().dyn_cast<BlockArgument>();
591       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
592         continue;
593       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
594       if (result.getType() != warpOperand.getType())
595         continue;
596       valForwarded = warpOperand;
597       resultIndex = operand.getOperandNumber();
598       break;
599     }
600     if (!valForwarded)
601       return failure();
602     warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
603     return success();
604   }
605 };
606 
607 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
608   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
609   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
610                                 PatternRewriter &rewriter) const override {
611     OpOperand *operand = getWarpResult(
612         warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
613     if (!operand)
614       return failure();
615     unsigned int operandNumber = operand->getOperandNumber();
616     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
617     Location loc = broadcastOp.getLoc();
618     auto destVecType =
619         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
620     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
621         rewriter, warpOp, {broadcastOp.getSource()},
622         {broadcastOp.getSource().getType()});
623     rewriter.setInsertionPointAfter(newWarpOp);
624     Value broadcasted = rewriter.create<vector::BroadcastOp>(
625         loc, destVecType, newWarpOp->getResults().back());
626     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
627 
628     return success();
629   }
630 };
631 
632 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
633 /// the scf.ForOp is the last operation in the region so that it doesn't change
634 /// the order of execution. This creates a new scf.for region after the
635 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
636 /// WarpExecuteOnLane0Op region. Example:
637 /// ```
638 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
639 ///   ...
640 ///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
641 ///   -> (vector<128xf32>) {
642 ///     ...
643 ///     scf.yield %r : vector<128xf32>
644 ///   }
645 ///   vector.yield %v1 : vector<128xf32>
646 /// }
647 /// ```
648 /// To:
649 /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
650 ///   ...
651 ///   vector.yield %v : vector<128xf32>
652 /// }
653 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
654 ///   -> (vector<4xf32>) {
655 ///     %iw = vector.warp_execute_on_lane_0(%laneid)
656 ///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
657 ///     ^bb0(%arg: vector<128xf32>):
658 ///       ...
659 ///       vector.yield %ir : vector<128xf32>
660 ///     }
661 ///     scf.yield %iw : vector<4xf32>
662 ///  }
663 /// ```
664 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
665   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
666   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
667                                 PatternRewriter &rewriter) const override {
668     auto yield = cast<vector::YieldOp>(
669         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
670     // Only pick up forOp if it is the last op in the region.
671     Operation *lastNode = yield->getPrevNode();
672     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
673     if (!forOp)
674       return failure();
675     SmallVector<Value> newOperands;
676     SmallVector<unsigned> resultIdx;
677     // Collect all the outputs coming from the forOp.
678     for (OpOperand &yieldOperand : yield->getOpOperands()) {
679       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
680         continue;
681       auto forResult = yieldOperand.get().cast<OpResult>();
682       newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
683       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
684       resultIdx.push_back(yieldOperand.getOperandNumber());
685     }
686     OpBuilder::InsertionGuard g(rewriter);
687     rewriter.setInsertionPointAfter(warpOp);
688     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
689     // inside.
690     auto newForOp = rewriter.create<scf::ForOp>(
691         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
692         forOp.getStep(), newOperands);
693     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
694     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
695         warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
696         warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
697         forOp.getResultTypes());
698 
699     SmallVector<Value> argMapping;
700     argMapping.push_back(newForOp.getInductionVar());
701     for (Value args : innerWarp.getBody()->getArguments()) {
702       argMapping.push_back(args);
703     }
704     SmallVector<Value> yieldOperands;
705     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
706       yieldOperands.push_back(operand);
707     rewriter.eraseOp(forOp.getBody()->getTerminator());
708     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
709     rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
710     rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
711     rewriter.setInsertionPointAfter(innerWarp);
712     rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
713     rewriter.eraseOp(forOp);
714     // Replace the warpOp result coming from the original ForOp.
715     for (const auto &res : llvm::enumerate(resultIdx)) {
716       warpOp.getResult(res.value())
717           .replaceAllUsesWith(newForOp.getResult(res.index()));
718       newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
719     }
720     return success();
721   }
722 };
723 
724 } // namespace
725 
726 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
727     RewritePatternSet &patterns,
728     const WarpExecuteOnLane0LoweringOptions &options) {
729   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
730 }
731 
732 void mlir::vector::populateDistributeTransferWriteOpPatterns(
733     RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
734   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
735 }
736 
737 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
738     RewritePatternSet &patterns) {
739   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
740                WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
741       patterns.getContext());
742 }
743 
744 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
745   Block *body = warpOp.getBody();
746 
747   // Keep track of the ops we want to hoist.
748   llvm::SmallSetVector<Operation *, 8> opsToMove;
749 
750   // Helper to check if a value is or will be defined outside of the region.
751   auto isDefinedOutsideOfBody = [&](Value value) {
752     auto *definingOp = value.getDefiningOp();
753     return (definingOp && opsToMove.count(definingOp)) ||
754            warpOp.isDefinedOutsideOfRegion(value);
755   };
756 
757   // Do not use walk here, as we do not want to go into nested regions and hoist
758   // operations from there.
759   for (auto &op : body->without_terminator()) {
760     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
761       return result.getType().isa<VectorType>();
762     });
763     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
764       opsToMove.insert(&op);
765   }
766 
767   // Move all the ops marked as uniform outside of the region.
768   for (Operation *op : opsToMove)
769     op->moveBefore(warpOp);
770 }
771