1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/Transforms/SideEffectUtils.h"
17 
18 using namespace mlir;
19 using namespace mlir::vector;
20 
/// Rewrite `warpOp` into an scf.if that is executed only by lane 0.
///
/// Vectors passed into the region as block arguments and values yielded out
/// of the region are exchanged through scratch-pad buffers allocated with
/// `options.warpAllocationFn`. `options.warpSyncronizationFn` is inserted
/// between the stores and the corresponding loads so that the written data
/// is visible across lanes.
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op guarded on `laneid == 0`.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  // Drop the auto-created terminator; a fresh scf.yield is added at the end.
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer. Each lane stores its distributed slice (outside the scf.if);
  // lane 0 then loads the full vector inside the scf.if as the replacement
  // for the corresponding block argument.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    // Allocate the buffer outside the scf.if.
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer at offset `laneid * storeSize`.
    // NOTE(review): uses shape[0] with a single index, so this assumes a 1-D
    // vector distributed along dim 0 — confirm against callers.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer (lane 0 only, inside the scf.if).
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp, replacing the block arguments with the
  // vectors loaded from the scratch-pad buffers.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    // Allocate the buffer outside the scf.if.
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer (inside the scf.if, i.e. lane 0 only),
    // just before the terminator.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp, executed by every lane).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      // Otherwise the result is distributed: each lane loads its own slice.
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}
137 
/// Helper to create a new WarpExecuteOnLane0Op with different signature.
/// Moves the body of `warpOp` into a fresh op that returns `newReturnTypes`
/// and rewrites the terminator to yield `newYieldedValues`. The original op
/// is left in place; the caller is responsible for replacing/erasing it.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  // Steal the body: inline the old region in front of the auto-created entry
  // block, then erase that now-redundant placeholder block.
  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  // Swap the terminator operands for the new yielded values.
  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}
164 
/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// The new op yields the original values followed by `newYieldedValues` and
/// returns the original result types followed by `newReturnTypes`. The
/// original op's results are replaced by the leading results of the new op;
/// the appended (trailing) results are left for the caller to use.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // New result type list = old result types + appended types.
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  // New yield operand list = old yield operands + appended values.
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  // Replace the original op's results with the matching leading results of
  // the new op.
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}
183 
184 /// Helper to know if an op can be hoisted out of the region.
185 static bool canBeHoisted(Operation *op,
186                          function_ref<bool(Value)> definedOutside) {
187   return llvm::all_of(op->getOperands(), definedOutside) &&
188          isSideEffectFree(op) && op->getNumRegions() == 0;
189 }
190 
191 /// Return a value yielded by `warpOp` which statifies the filter lamdba
192 /// condition and is not dead.
193 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
194                                 std::function<bool(Operation *)> fn) {
195   auto yield = cast<vector::YieldOp>(
196       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
197   for (OpOperand &yieldOperand : yield->getOpOperands()) {
198     Value yieldValues = yieldOperand.get();
199     Operation *definedOp = yieldValues.getDefiningOp();
200     if (definedOp && fn(definedOp)) {
201       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
202         return &yieldOperand;
203     }
204   }
205   return {};
206 }
207 
208 // Clones `op` into a new operation that takes `operands` and returns
209 // `resultTypes`.
210 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
211                                               Location loc, Operation *op,
212                                               ArrayRef<Value> operands,
213                                               ArrayRef<Type> resultTypes) {
214   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
215                      op->getAttrs());
216   return rewriter.create(res);
217 }
218 
219 /// Currently the distribution map is implicit based on the vector shape. In the
220 /// future it will be part of the op.
221 /// Example:
222 /// ```
223 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
224 ///   ...
225 ///   vector.yield %3 : vector<32x16x64xf32>
226 /// }
227 /// ```
228 /// Would have an implicit map of:
229 /// `(d0, d1, d2) -> (d0, d2)`
230 static AffineMap calculateImplicitMap(Value yield, Value ret) {
231   auto srcType = yield.getType().cast<VectorType>();
232   auto dstType = ret.getType().cast<VectorType>();
233   SmallVector<AffineExpr> perm;
234   // Check which dimensions of the yield value are different than the dimensions
235   // of the result to know the distributed dimensions. Then associate each
236   // distributed dimension to an ID in order.
237   for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
238     if (srcType.getDimSize(i) != dstType.getDimSize(i))
239       perm.push_back(getAffineDimExpr(i, yield.getContext()));
240   }
241   auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
242   return map;
243 }
244 
245 namespace {
246 
247 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
248   WarpOpToScfForPattern(MLIRContext *context,
249                         const WarpExecuteOnLane0LoweringOptions &options,
250                         PatternBenefit benefit = 1)
251       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
252         options(options) {}
253 
254   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
255                                 PatternRewriter &rewriter) const override {
256     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
257   }
258 
259 private:
260   const WarpExecuteOnLane0LoweringOptions &options;
261 };
262 
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
  /// are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    // Compute the distributed vector shape: each distributed dimension is
    // divided by the warp size and must divide evenly.
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    // Yield the written vector out of the region; it comes back as the last
    // result of the new warp op with the distributed type.
    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and delete
    // the old op.
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    // Offset each distributed index by the lane's slice:
    // index = index + laneid * (distributed dim size).
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    // Rewire the cloned write to use the distributed vector and indices.
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vector of 1 element for now to not serialize large vector
    // store. This can later be controlled by user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    // Yield the written vector so it is available outside the region.
    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    // The cloned write consumes the value yielded by the first warp op.
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    // All operands except the stored vector must come from outside the
    // region (uniform values), otherwise the write cannot be moved out.
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};
414 
/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a live yielded value defined by an elementwise-mappable op.
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    // Additionally yield every operand of the elementwise op out of the
    // region, with its type distributed the same way as the result.
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        // Scalar result: operands keep their (scalar) types.
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Recreate the elementwise op outside the region, operating on the
    // distributed operands (the results appended after `numResults`).
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};
482 
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a live yielded value defined by a transfer_read.
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    // Derive the distribution map from the read's type vs the distributed
    // result type, and map distributed dims to source indices.
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Offset each distributed index by laneid * (distributed dim size).
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    // Create the distributed read after the warp op; the original read inside
    // the region becomes dead once its result is replaced.
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};
541 
542 /// Remove any result that has no use along with the matching yieldOp operand.
543 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
544 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
545   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
546   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
547                                 PatternRewriter &rewriter) const override {
548     SmallVector<Type> resultTypes;
549     SmallVector<Value> yieldValues;
550     auto yield = cast<vector::YieldOp>(
551         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
552     for (OpResult result : warpOp.getResults()) {
553       if (result.use_empty())
554         continue;
555       resultTypes.push_back(result.getType());
556       yieldValues.push_back(yield.getOperand(result.getResultNumber()));
557     }
558     if (yield.getNumOperands() == yieldValues.size())
559       return failure();
560     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
561         rewriter, warpOp, yieldValues, resultTypes);
562     unsigned resultIndex = 0;
563     for (OpResult result : warpOp.getResults()) {
564       if (result.use_empty())
565         continue;
566       result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
567     }
568     rewriter.eraseOp(warpOp);
569     return success();
570   }
571 };
572 
573 // If an operand is directly yielded out of the region we can forward it
574 // directly and it doesn't need to go through the region.
575 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
576   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
577   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
578                                 PatternRewriter &rewriter) const override {
579     SmallVector<Type> resultTypes;
580     SmallVector<Value> yieldValues;
581     auto yield = cast<vector::YieldOp>(
582         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
583     Value valForwarded;
584     unsigned resultIndex;
585     for (OpOperand &operand : yield->getOpOperands()) {
586       Value result = warpOp.getResult(operand.getOperandNumber());
587       if (result.use_empty())
588         continue;
589 
590       // Assume all the values coming from above are uniform.
591       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
592         if (result.getType() != operand.get().getType())
593           continue;
594         valForwarded = operand.get();
595         resultIndex = operand.getOperandNumber();
596         break;
597       }
598       auto arg = operand.get().dyn_cast<BlockArgument>();
599       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
600         continue;
601       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
602       if (result.getType() != warpOperand.getType())
603         continue;
604       valForwarded = warpOperand;
605       resultIndex = operand.getOperandNumber();
606       break;
607     }
608     if (!valForwarded)
609       return failure();
610     warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
611     return success();
612   }
613 };
614 
615 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
616   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
617   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
618                                 PatternRewriter &rewriter) const override {
619     OpOperand *operand = getWarpResult(
620         warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
621     if (!operand)
622       return failure();
623     unsigned int operandNumber = operand->getOperandNumber();
624     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
625     Location loc = broadcastOp.getLoc();
626     auto destVecType =
627         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
628     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
629         rewriter, warpOp, {broadcastOp.getSource()},
630         {broadcastOp.getSource().getType()});
631     rewriter.setInsertionPointAfter(newWarpOp);
632     Value broadcasted = rewriter.create<vector::BroadcastOp>(
633         loc, destVecType, newWarpOp->getResults().back());
634     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
635 
636     return success();
637   }
638 };
639 
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      // The warp op result that forwarded the for op result becomes the init
      // value of the new (outer) for loop.
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      // Yield the for op's init value instead, so the for op can be detached
      // from the terminator and moved out.
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
    // inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    // The inner warp op receives the distributed iter_args and exposes them
    // to the moved loop body with the original (undistributed) types.
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    // Remap the old loop body arguments: induction variable first, then the
    // inner warp op's block arguments in place of the old iter_args.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    // Save the old loop terminator's operands before erasing it, then move
    // the loop body into the inner warp op.
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    // Terminate the inner warp op with the values the old loop yielded, and
    // the new for op with the inner warp op's results.
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operand offset 3 skips the for op's lower bound, upper bound and
      // step operands to reach its iter_args.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};
731 
732 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
733 /// The vector is reduced in parallel. Currently limited to vector size matching
734 /// the warpOp size. E.g.:
735 /// ```
736 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
737 ///   %0 = "some_def"() : () -> (vector<32xf32>)
738 ///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
739 ///   vector_ext.yield %1 : f32
740 /// }
741 /// ```
742 /// is lowered to:
743 /// ```
744 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
745 ///   %1 = "some_def"() : () -> (vector<32xf32>)
746 ///   vector_ext.yield %1 : vector<32xf32>
747 /// }
748 /// %a = vector.extract %0[0] : vector<1xf32>
749 /// %r = ("warp.reduction %a")
750 /// ```
751 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
752   WarpOpReduction(MLIRContext *context,
753                   DistributedReductionFn distributedReductionFn,
754                   PatternBenefit benefit = 1)
755       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
756         distributedReductionFn(distributedReductionFn) {}
757 
758   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
759                                 PatternRewriter &rewriter) const override {
760     OpOperand *yieldOperand = getWarpResult(
761         warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
762     if (!yieldOperand)
763       return failure();
764 
765     auto reductionOp =
766         cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
767     auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
768     // Only rank 1 vectors supported.
769     if (vectorType.getRank() != 1)
770       return rewriter.notifyMatchFailure(
771           warpOp, "Only rank 1 reductions can be distributed.");
772     // Only warp_size-sized vectors supported.
773     if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
774       return rewriter.notifyMatchFailure(
775           warpOp, "Reduction vector dimension must match was size.");
776     // Only f32 and i32 element types are supported.
777     if (!reductionOp.getType().isF32() &&
778         !reductionOp.getType().isSignlessInteger(32))
779       return rewriter.notifyMatchFailure(
780           warpOp,
781           "Reduction distribution currently only supports 32bits types.");
782 
783     Location yieldLoc = yieldOperand->getOwner()->getLoc();
784 
785     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
786     unsigned operandIndex = yieldOperand->getOperandNumber();
787     SmallVector<Value> yieldValues = {reductionOp.getVector()};
788     SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
789     unsigned numResults = warpOp.getNumResults();
790     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
791         rewriter, warpOp, yieldValues, retTypes);
792     rewriter.setInsertionPointAfter(newWarpOp);
793 
794     // Every lane has one scalar value. These should be reduced.
795     Value laneValVec = newWarpOp.getResult(numResults);
796     Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
797     laneVal =
798         distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
799                                reductionOp.getKind(), newWarpOp.getWarpSize());
800     newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
801     return success();
802   }
803 
804 private:
805   DistributedReductionFn distributedReductionFn;
806 };
807 
808 } // namespace
809 
810 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
811     RewritePatternSet &patterns,
812     const WarpExecuteOnLane0LoweringOptions &options) {
813   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
814 }
815 
816 void mlir::vector::populateDistributeTransferWriteOpPatterns(
817     RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
818   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
819 }
820 
821 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
822     RewritePatternSet &patterns) {
823   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
824                WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
825       patterns.getContext());
826 }
827 
828 void mlir::vector::populateDistributeReduction(
829     RewritePatternSet &patterns,
830     DistributedReductionFn distributedReductionFn) {
831   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
832 }
833 
834 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
835   Block *body = warpOp.getBody();
836 
837   // Keep track of the ops we want to hoist.
838   llvm::SmallSetVector<Operation *, 8> opsToMove;
839 
840   // Helper to check if a value is or will be defined outside of the region.
841   auto isDefinedOutsideOfBody = [&](Value value) {
842     auto *definingOp = value.getDefiningOp();
843     return (definingOp && opsToMove.count(definingOp)) ||
844            warpOp.isDefinedOutsideOfRegion(value);
845   };
846 
847   // Do not use walk here, as we do not want to go into nested regions and hoist
848   // operations from there.
849   for (auto &op : body->without_terminator()) {
850     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
851       return result.getType().isa<VectorType>();
852     });
853     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
854       opsToMove.insert(&op);
855   }
856 
857   // Move all the ops marked as uniform outside of the region.
858   for (Operation *op : opsToMove)
859     op->moveBefore(warpOp);
860 }
861