1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14 #include "mlir/IR/BlockAndValueMapping.h"
15 #include "mlir/Transforms/SideEffectUtils.h"
16 
17 using namespace mlir;
18 using namespace mlir::vector;
19 
20 static LogicalResult
21 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
22                       const WarpExecuteOnLane0LoweringOptions &options) {
23   assert(warpOp.getBodyRegion().hasOneBlock() &&
24          "expected WarpOp with single block");
25   Block *warpOpBody = &warpOp.getBodyRegion().front();
26   Location loc = warpOp.getLoc();
27 
28   // Passed all checks. Start rewriting.
29   OpBuilder::InsertionGuard g(rewriter);
30   rewriter.setInsertionPoint(warpOp);
31 
32   // Create scf.if op.
33   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
34   Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
35                                                  warpOp.getLaneid(), c0);
36   auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
37                                          /*withElseRegion=*/false);
38   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
39 
40   // Store vectors that are defined outside of warpOp into the scratch pad
41   // buffer.
42   SmallVector<Value> bbArgReplacements;
43   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
44     Value val = it.value();
45     Value bbArg = warpOpBody->getArgument(it.index());
46 
47     rewriter.setInsertionPoint(ifOp);
48     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
49                                             bbArg.getType());
50 
51     // Store arg vector into buffer.
52     rewriter.setInsertionPoint(ifOp);
53     auto vectorType = val.getType().cast<VectorType>();
54     int64_t storeSize = vectorType.getShape()[0];
55     Value storeOffset = rewriter.create<arith::MulIOp>(
56         loc, warpOp.getLaneid(),
57         rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
58     rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
59 
60     // Load bbArg vector from buffer.
61     rewriter.setInsertionPointToStart(ifOp.thenBlock());
62     auto bbArgType = bbArg.getType().cast<VectorType>();
63     Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
64     bbArgReplacements.push_back(loadOp);
65   }
66 
67   // Insert sync after all the stores and before all the loads.
68   if (!warpOp.getArgs().empty()) {
69     rewriter.setInsertionPoint(ifOp);
70     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
71   }
72 
73   // Move body of warpOp to ifOp.
74   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
75 
76   // Rewrite terminator and compute replacements of WarpOp results.
77   SmallVector<Value> replacements;
78   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
79   Location yieldLoc = yieldOp.getLoc();
80   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
81     Value val = it.value();
82     Type resultType = warpOp->getResultTypes()[it.index()];
83     rewriter.setInsertionPoint(ifOp);
84     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
85                                             val.getType());
86 
87     // Store yielded value into buffer.
88     rewriter.setInsertionPoint(yieldOp);
89     if (val.getType().isa<VectorType>())
90       rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
91     else
92       rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
93 
94     // Load value from buffer (after warpOp).
95     rewriter.setInsertionPointAfter(ifOp);
96     if (resultType == val.getType()) {
97       // Result type and yielded value type are the same. This is a broadcast.
98       // E.g.:
99       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
100       //   vector.yield %cst : f32
101       // }
102       // Both types are f32. The constant %cst is broadcasted to all lanes.
103       // This is described in more detail in the documentation of the op.
104       Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
105       replacements.push_back(loadOp);
106     } else {
107       auto loadedVectorType = resultType.cast<VectorType>();
108       int64_t loadSize = loadedVectorType.getShape()[0];
109 
110       // loadOffset = laneid * loadSize
111       Value loadOffset = rewriter.create<arith::MulIOp>(
112           loc, warpOp.getLaneid(),
113           rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
114       Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
115                                                      buffer, loadOffset);
116       replacements.push_back(loadOp);
117     }
118   }
119 
120   // Insert sync after all the stores and before all the loads.
121   if (!yieldOp.operands().empty()) {
122     rewriter.setInsertionPointAfter(ifOp);
123     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
124   }
125 
126   // Delete terminator and add empty scf.yield.
127   rewriter.eraseOp(yieldOp);
128   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
129   rewriter.create<scf::YieldOp>(yieldLoc);
130 
131   // Compute replacements for WarpOp results.
132   rewriter.replaceOp(warpOp, replacements);
133 
134   return success();
135 }
136 
137 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
138 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
139     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
140     ValueRange newYieldedValues, TypeRange newReturnTypes) {
141   // Create a new op before the existing one, with the extra operands.
142   OpBuilder::InsertionGuard g(rewriter);
143   rewriter.setInsertionPoint(warpOp);
144   auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
145       warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
146       warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
147 
148   Region &opBody = warpOp.getBodyRegion();
149   Region &newOpBody = newWarpOp.getBodyRegion();
150   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
151   auto yield =
152       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
153 
154   rewriter.updateRootInPlace(
155       yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
156   return newWarpOp;
157 }
158 
159 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
160 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
161     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
162     ValueRange newYieldedValues, TypeRange newReturnTypes) {
163   SmallVector<Type> types(warpOp.getResultTypes().begin(),
164                           warpOp.getResultTypes().end());
165   types.append(newReturnTypes.begin(), newReturnTypes.end());
166   auto yield = cast<vector::YieldOp>(
167       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
168   SmallVector<Value> yieldValues(yield.getOperands().begin(),
169                                  yield.getOperands().end());
170   yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
171   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
172       rewriter, warpOp, yieldValues, types);
173   rewriter.replaceOp(warpOp,
174                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
175   return newWarpOp;
176 }
177 
178 /// Helper to know if an op can be hoisted out of the region.
179 static bool canBeHoisted(Operation *op,
180                          function_ref<bool(Value)> definedOutside) {
181   return llvm::all_of(op->getOperands(), definedOutside) &&
182          isSideEffectFree(op) && op->getNumRegions() == 0;
183 }
184 
185 /// Return a value yielded by `warpOp` which statifies the filter lamdba
186 /// condition and is not dead.
187 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
188                                 std::function<bool(Operation *)> fn) {
189   auto yield = cast<vector::YieldOp>(
190       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
191   for (OpOperand &yieldOperand : yield->getOpOperands()) {
192     Value yieldValues = yieldOperand.get();
193     Operation *definedOp = yieldValues.getDefiningOp();
194     if (definedOp && fn(definedOp)) {
195       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
196         return &yieldOperand;
197     }
198   }
199   return {};
200 }
201 
202 // Clones `op` into a new operation that takes `operands` and returns
203 // `resultTypes`.
204 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
205                                               Location loc, Operation *op,
206                                               ArrayRef<Value> operands,
207                                               ArrayRef<Type> resultTypes) {
208   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
209                      op->getAttrs());
210   return rewriter.create(res);
211 }
212 
213 /// Currently the distribution map is implicit based on the vector shape. In the
214 /// future it will be part of the op.
215 /// Example:
216 /// ```
217 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
218 ///   ...
219 ///   vector.yield %3 : vector<32x16x64xf32>
220 /// }
221 /// ```
222 /// Would have an implicit map of:
223 /// `(d0, d1, d2) -> (d0, d2)`
224 static AffineMap calculateImplicitMap(Value yield, Value ret) {
225   auto srcType = yield.getType().cast<VectorType>();
226   auto dstType = ret.getType().cast<VectorType>();
227   SmallVector<AffineExpr> perm;
228   // Check which dimensions of the yield value are different than the dimensions
229   // of the result to know the distributed dimensions. Then associate each
230   // distributed dimension to an ID in order.
231   for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
232     if (srcType.getDimSize(i) != dstType.getDimSize(i))
233       perm.push_back(getAffineDimExpr(i, yield.getContext()));
234   }
235   auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
236   return map;
237 }
238 
239 namespace {
240 
241 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
242   WarpOpToScfForPattern(MLIRContext *context,
243                         const WarpExecuteOnLane0LoweringOptions &options,
244                         PatternBenefit benefit = 1)
245       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
246         options(options) {}
247 
248   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
249                                 PatternRewriter &rewriter) const override {
250     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
251   }
252 
253 private:
254   const WarpExecuteOnLane0LoweringOptions &options;
255 };
256 
257 /// Distribute transfer_write ops based on the affine map returned by
258 /// `distributionMapFn`.
259 /// Example:
260 /// ```
261 /// %0 = vector.warp_execute_on_lane_0(%id){
262 ///   ...
263 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
264 ///   vector.yield
265 /// }
266 /// ```
267 /// To
268 /// ```
269 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
270 ///   ...
271 ///   vector.yield %v : vector<32xf32>
272 /// }
273 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
274 struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
275   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
276                       PatternBenefit b = 1)
277       : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
278         distributionMapFn(fn) {}
279 
280   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
281   /// are multiples of the distribution ratio are supported at the moment.
282   LogicalResult tryDistributeOp(RewriterBase &rewriter,
283                                 vector::TransferWriteOp writeOp,
284                                 WarpExecuteOnLane0Op warpOp) const {
285     AffineMap map = distributionMapFn(writeOp);
286     SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
287                                      writeOp.getVectorType().getShape().end());
288     assert(map.getNumResults() == 1 &&
289            "multi-dim distribution not implemented yet");
290     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
291       unsigned position = map.getDimPosition(i);
292       if (targetShape[position] % warpOp.getWarpSize() != 0)
293         return failure();
294       targetShape[position] = targetShape[position] / warpOp.getWarpSize();
295     }
296     VectorType targetType =
297         VectorType::get(targetShape, writeOp.getVectorType().getElementType());
298 
299     SmallVector<Value> yieldValues = {writeOp.getVector()};
300     SmallVector<Type> retTypes = {targetType};
301     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
302         rewriter, warpOp, yieldValues, retTypes);
303     rewriter.setInsertionPointAfter(newWarpOp);
304 
305     // Move op outside of region: Insert clone at the insertion point and delete
306     // the old op.
307     auto newWriteOp =
308         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
309     rewriter.eraseOp(writeOp);
310 
311     rewriter.setInsertionPoint(newWriteOp);
312     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
313     Location loc = newWriteOp.getLoc();
314     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
315                                newWriteOp.getIndices().end());
316     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
317       AffineExpr d0, d1;
318       bindDims(newWarpOp.getContext(), d0, d1);
319       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
320       if (!indexExpr)
321         continue;
322       unsigned indexPos = indexExpr.getPosition();
323       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
324       auto scale =
325           getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
326       indices[indexPos] =
327           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
328                                   {indices[indexPos], newWarpOp.getLaneid()});
329     }
330     newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
331     newWriteOp.getIndicesMutable().assign(indices);
332 
333     return success();
334   }
335 
336   /// Extract TransferWriteOps of vector<1x> into a separate warp op.
337   LogicalResult tryExtractOp(RewriterBase &rewriter,
338                              vector::TransferWriteOp writeOp,
339                              WarpExecuteOnLane0Op warpOp) const {
340     Location loc = writeOp.getLoc();
341     VectorType vecType = writeOp.getVectorType();
342 
343     // Only vector<1x> is supported at the moment.
344     if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
345       return failure();
346 
347     // Do not process warp ops that contain only TransferWriteOps.
348     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
349           return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
350         }))
351       return failure();
352 
353     SmallVector<Value> yieldValues = {writeOp.getVector()};
354     SmallVector<Type> retTypes = {vecType};
355     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
356         rewriter, warpOp, yieldValues, retTypes);
357     rewriter.setInsertionPointAfter(newWarpOp);
358 
359     // Create a second warp op that contains only writeOp.
360     auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
361         loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
362     Block &body = secondWarpOp.getBodyRegion().front();
363     rewriter.setInsertionPointToStart(&body);
364     auto newWriteOp =
365         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
366     newWriteOp.getVectorMutable().assign(
367         newWarpOp.getResult(newWarpOp.getNumResults() - 1));
368     rewriter.eraseOp(writeOp);
369     rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
370     return success();
371   }
372 
373   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
374                                 PatternRewriter &rewriter) const override {
375     // Ops with mask not supported yet.
376     if (writeOp.getMask())
377       return failure();
378 
379     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
380     if (!warpOp)
381       return failure();
382 
383     // There must be no op with a side effect after writeOp.
384     Operation *nextOp = writeOp.getOperation();
385     while ((nextOp = nextOp->getNextNode()))
386       if (!isSideEffectFree(nextOp))
387         return failure();
388 
389     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
390           return writeOp.getVector() == value ||
391                  warpOp.isDefinedOutsideOfRegion(value);
392         }))
393       return failure();
394 
395     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
396       return success();
397 
398     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
399       return success();
400 
401     return failure();
402   }
403 
404 private:
405   DistributionMapFn distributionMapFn;
406 };
407 
408 /// Sink out elementwise op feeding into a warp op yield.
409 /// ```
410 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
411 ///   ...
412 ///   %3 = arith.addf %1, %2 : vector<32xf32>
413 ///   vector.yield %3 : vector<32xf32>
414 /// }
415 /// ```
416 /// To
417 /// ```
418 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
419 /// vector<1xf32>, vector<1xf32>) {
420 ///   ...
421 ///   %4 = arith.addf %2, %3 : vector<32xf32>
422 ///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
423 ///   vector<32xf32>
424 /// }
425 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
426 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
427   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
428   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
429                                 PatternRewriter &rewriter) const override {
430     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
431       return OpTrait::hasElementwiseMappableTraits(op);
432     });
433     if (!yieldOperand)
434       return failure();
435     Operation *elementWise = yieldOperand->get().getDefiningOp();
436     unsigned operandIndex = yieldOperand->getOperandNumber();
437     Value distributedVal = warpOp.getResult(operandIndex);
438     SmallVector<Value> yieldValues;
439     SmallVector<Type> retTypes;
440     Location loc = warpOp.getLoc();
441     for (OpOperand &operand : elementWise->getOpOperands()) {
442       Type targetType;
443       if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
444         // If the result type is a vector, the operands must also be vectors.
445         auto operandType = operand.get().getType().cast<VectorType>();
446         targetType =
447             VectorType::get(vecType.getShape(), operandType.getElementType());
448       } else {
449         auto operandType = operand.get().getType();
450         assert(!operandType.isa<VectorType>() &&
451                "unexpected yield of vector from op with scalar result type");
452         targetType = operandType;
453       }
454       retTypes.push_back(targetType);
455       yieldValues.push_back(operand.get());
456     }
457     unsigned numResults = warpOp.getNumResults();
458     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
459         rewriter, warpOp, yieldValues, retTypes);
460     rewriter.setInsertionPointAfter(newWarpOp);
461     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
462                                    elementWise->getOperands().end());
463     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
464       newOperands[i] = newWarpOp.getResult(i + numResults);
465     }
466     OpBuilder::InsertionGuard g(rewriter);
467     rewriter.setInsertionPointAfter(newWarpOp);
468     Operation *newOp = cloneOpWithOperandsAndTypes(
469         rewriter, loc, elementWise, newOperands,
470         {newWarpOp.getResult(operandIndex).getType()});
471     newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
472     return success();
473   }
474 };
475 
476 /// Sink out transfer_read op feeding into a warp op yield.
477 /// ```
478 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
479 ///   ...
480 //    %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
481 //    vector<32xf32>
482 ///   vector.yield %2 : vector<32xf32>
483 /// }
484 /// ```
485 /// To
486 /// ```
487 /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
488 /// vector<1xf32>, vector<1xf32>) {
489 ///   ...
490 ///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
491 ///   vector<32xf32> vector.yield %2 : vector<32xf32>
492 /// }
493 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
494 struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
495   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
496   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
497                                 PatternRewriter &rewriter) const override {
498     OpOperand *operand = getWarpResult(
499         warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
500     if (!operand)
501       return failure();
502     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
503     unsigned operandIndex = operand->getOperandNumber();
504     Value distributedVal = warpOp.getResult(operandIndex);
505 
506     SmallVector<Value, 4> indices(read.getIndices().begin(),
507                                   read.getIndices().end());
508     AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
509     AffineMap indexMap = map.compose(read.getPermutationMap());
510     OpBuilder::InsertionGuard g(rewriter);
511     rewriter.setInsertionPointAfter(warpOp);
512     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
513       AffineExpr d0, d1;
514       bindDims(read.getContext(), d0, d1);
515       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
516       if (!indexExpr)
517         continue;
518       unsigned indexPos = indexExpr.getPosition();
519       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
520       int64_t scale =
521           distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
522       indices[indexPos] =
523           makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
524                                   {indices[indexPos], warpOp.getLaneid()});
525     }
526     Value newRead = rewriter.create<vector::TransferReadOp>(
527         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
528         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
529         read.getInBoundsAttr());
530     distributedVal.replaceAllUsesWith(newRead);
531     return success();
532   }
533 };
534 
535 /// Remove any result that has no use along with the matching yieldOp operand.
536 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
537 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
538   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
539   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
540                                 PatternRewriter &rewriter) const override {
541     SmallVector<Type> resultTypes;
542     SmallVector<Value> yieldValues;
543     auto yield = cast<vector::YieldOp>(
544         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
545     for (OpResult result : warpOp.getResults()) {
546       if (result.use_empty())
547         continue;
548       resultTypes.push_back(result.getType());
549       yieldValues.push_back(yield.getOperand(result.getResultNumber()));
550     }
551     if (yield.getNumOperands() == yieldValues.size())
552       return failure();
553     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
554         rewriter, warpOp, yieldValues, resultTypes);
555     unsigned resultIndex = 0;
556     for (OpResult result : warpOp.getResults()) {
557       if (result.use_empty())
558         continue;
559       result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
560     }
561     rewriter.eraseOp(warpOp);
562     return success();
563   }
564 };
565 
566 // If an operand is directly yielded out of the region we can forward it
567 // directly and it doesn't need to go through the region.
568 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
569   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
570   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
571                                 PatternRewriter &rewriter) const override {
572     SmallVector<Type> resultTypes;
573     SmallVector<Value> yieldValues;
574     auto yield = cast<vector::YieldOp>(
575         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
576     Value valForwarded;
577     unsigned resultIndex;
578     for (OpOperand &operand : yield->getOpOperands()) {
579       Value result = warpOp.getResult(operand.getOperandNumber());
580       if (result.use_empty())
581         continue;
582 
583       // Assume all the values coming from above are uniform.
584       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
585         if (result.getType() != operand.get().getType())
586           continue;
587         valForwarded = operand.get();
588         resultIndex = operand.getOperandNumber();
589         break;
590       }
591       auto arg = operand.get().dyn_cast<BlockArgument>();
592       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
593         continue;
594       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
595       if (result.getType() != warpOperand.getType())
596         continue;
597       valForwarded = warpOperand;
598       resultIndex = operand.getOperandNumber();
599       break;
600     }
601     if (!valForwarded)
602       return failure();
603     warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
604     return success();
605   }
606 };
607 
608 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
609   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
610   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
611                                 PatternRewriter &rewriter) const override {
612     OpOperand *operand = getWarpResult(
613         warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
614     if (!operand)
615       return failure();
616     unsigned int operandNumber = operand->getOperandNumber();
617     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
618     Location loc = broadcastOp.getLoc();
619     auto destVecType =
620         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
621     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
622         rewriter, warpOp, {broadcastOp.getSource()},
623         {broadcastOp.getSource().getType()});
624     rewriter.setInsertionPointAfter(newWarpOp);
625     Value broadcasted = rewriter.create<vector::BroadcastOp>(
626         loc, destVecType, newWarpOp->getResults().back());
627     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
628 
629     return success();
630   }
631 };
632 
633 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
634 /// the scf.ForOp is the last operation in the region so that it doesn't change
635 /// the order of execution. This creates a new scf.for region after the
636 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
637 /// WarpExecuteOnLane0Op region. Example:
638 /// ```
639 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
640 ///   ...
641 ///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
642 ///   -> (vector<128xf32>) {
643 ///     ...
644 ///     scf.yield %r : vector<128xf32>
645 ///   }
646 ///   vector.yield %v1 : vector<128xf32>
647 /// }
648 /// ```
649 /// To:
650 /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
651 ///   ...
652 ///   vector.yield %v : vector<128xf32>
653 /// }
654 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
655 ///   -> (vector<4xf32>) {
656 ///     %iw = vector.warp_execute_on_lane_0(%laneid)
657 ///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
658 ///     ^bb0(%arg: vector<128xf32>):
659 ///       ...
660 ///       vector.yield %ir : vector<128xf32>
661 ///     }
662 ///     scf.yield %iw : vector<4xf32>
663 ///  }
664 /// ```
665 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
666   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
667   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
668                                 PatternRewriter &rewriter) const override {
669     auto yield = cast<vector::YieldOp>(
670         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
671     // Only pick up forOp if it is the last op in the region.
672     Operation *lastNode = yield->getPrevNode();
673     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
674     if (!forOp)
675       return failure();
676     SmallVector<Value> newOperands;
677     SmallVector<unsigned> resultIdx;
678     // Collect all the outputs coming from the forOp.
679     for (OpOperand &yieldOperand : yield->getOpOperands()) {
680       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
681         continue;
682       auto forResult = yieldOperand.get().cast<OpResult>();
683       newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
684       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
685       resultIdx.push_back(yieldOperand.getOperandNumber());
686     }
687     OpBuilder::InsertionGuard g(rewriter);
688     rewriter.setInsertionPointAfter(warpOp);
689     // Create a new for op outside the region with a WarpExecuteOnLane0Op region
690     // inside.
691     auto newForOp = rewriter.create<scf::ForOp>(
692         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
693         forOp.getStep(), newOperands);
694     rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
695     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
696         warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
697         warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
698         forOp.getResultTypes());
699 
700     SmallVector<Value> argMapping;
701     argMapping.push_back(newForOp.getInductionVar());
702     for (Value args : innerWarp.getBody()->getArguments()) {
703       argMapping.push_back(args);
704     }
705     SmallVector<Value> yieldOperands;
706     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
707       yieldOperands.push_back(operand);
708     rewriter.eraseOp(forOp.getBody()->getTerminator());
709     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
710     rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
711     rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
712     rewriter.setInsertionPointAfter(innerWarp);
713     rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
714     rewriter.eraseOp(forOp);
715     // Replace the warpOp result coming from the original ForOp.
716     for (const auto &res : llvm::enumerate(resultIdx)) {
717       warpOp.getResult(res.value())
718           .replaceAllUsesWith(newForOp.getResult(res.index()));
719       newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
720     }
721     return success();
722   }
723 };
724 
725 } // namespace
726 
727 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
728     RewritePatternSet &patterns,
729     const WarpExecuteOnLane0LoweringOptions &options) {
730   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
731 }
732 
733 void mlir::vector::populateDistributeTransferWriteOpPatterns(
734     RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
735   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
736 }
737 
738 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
739     RewritePatternSet &patterns) {
740   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
741                WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
742       patterns.getContext());
743 }
744 
745 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
746   Block *body = warpOp.getBody();
747 
748   // Keep track of the ops we want to hoist.
749   llvm::SmallSetVector<Operation *, 8> opsToMove;
750 
751   // Helper to check if a value is or will be defined outside of the region.
752   auto isDefinedOutsideOfBody = [&](Value value) {
753     auto *definingOp = value.getDefiningOp();
754     return (definingOp && opsToMove.count(definingOp)) ||
755            warpOp.isDefinedOutsideOfRegion(value);
756   };
757 
758   // Do not use walk here, as we do not want to go into nested regions and hoist
759   // operations from there.
760   for (auto &op : body->without_terminator()) {
761     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
762       return result.getType().isa<VectorType>();
763     });
764     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
765       opsToMove.insert(&op);
766   }
767 
768   // Move all the ops marked as uniform outside of the region.
769   for (Operation *op : opsToMove)
770     op->moveBefore(warpOp);
771 }
772