1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
#include "llvm/ADT/STLExtras.h"

#include <utility>
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
/// Lower a WarpExecuteOnLane0Op into an scf.if executed only by lane 0.
/// Values flowing into the warp op (block arguments) and out of it (yielded
/// results) are exchanged through scratch-pad buffers allocated with
/// `options.warpAllocationFn`; `options.warpSyncronizationFn` is inserted
/// between the stores and the loads so that all lanes observe the data.
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op guarded on `laneid == 0`.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  // Drop the auto-created scf.yield; a fresh one is added at the end, after
  // the warp body has been merged into the then-block.
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer. Every lane writes its distributed slice
    // at offset laneid * storeSize; only the leading dimension of the vector
    // is used to compute the offset.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer. Lane 0 (inside the scf.if) reads back
    // the full un_distributed vector starting at offset 0.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp, replacing the block arguments with the
  // values loaded from the buffers above.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer. This happens inside the scf.if, i.e.
    // only lane 0 writes.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      // Otherwise the result is distributed: each lane loads its own slice.
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}
139 
/// Helper to create a new WarpExecuteOnLane0Op with different signature.
/// The body of `warpOp` is moved into the new op and its terminator operands
/// are replaced by `newYieldedValues`. The old op is NOT erased here; callers
/// are expected to replace or erase it.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  // Inline the old body before the auto-created block, then erase that
  // auto-created (empty) block so the moved body is the only block left.
  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  // Update the terminator in place so the yielded values line up with
  // `newReturnTypes`.
  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}
166 
167 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
168 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
169     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
170     ValueRange newYieldedValues, TypeRange newReturnTypes) {
171   SmallVector<Type> types(warpOp.getResultTypes().begin(),
172                           warpOp.getResultTypes().end());
173   types.append(newReturnTypes.begin(), newReturnTypes.end());
174   auto yield = cast<vector::YieldOp>(
175       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
176   SmallVector<Value> yieldValues(yield.getOperands().begin(),
177                                  yield.getOperands().end());
178   yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
179   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
180       rewriter, warpOp, yieldValues, types);
181   rewriter.replaceOp(warpOp,
182                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
183   return newWarpOp;
184 }
185 
186 /// Helper to know if an op can be hoisted out of the region.
187 static bool canBeHoisted(Operation *op,
188                          function_ref<bool(Value)> definedOutside) {
189   return llvm::all_of(op->getOperands(), definedOutside) &&
190          isSideEffectFree(op) && op->getNumRegions() == 0;
191 }
192 
193 /// Return a value yielded by `warpOp` which statifies the filter lamdba
194 /// condition and is not dead.
195 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
196                                 std::function<bool(Operation *)> fn) {
197   auto yield = cast<vector::YieldOp>(
198       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
199   for (OpOperand &yieldOperand : yield->getOpOperands()) {
200     Value yieldValues = yieldOperand.get();
201     Operation *definedOp = yieldValues.getDefiningOp();
202     if (definedOp && fn(definedOp)) {
203       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
204         return &yieldOperand;
205     }
206   }
207   return {};
208 }
209 
210 // Clones `op` into a new operation that takes `operands` and returns
211 // `resultTypes`.
212 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
213                                               Location loc, Operation *op,
214                                               ArrayRef<Value> operands,
215                                               ArrayRef<Type> resultTypes) {
216   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
217                      op->getAttrs());
218   return rewriter.create(res);
219 }
220 
221 /// Currently the distribution map is implicit based on the vector shape. In the
222 /// future it will be part of the op.
223 /// Example:
224 /// ```
225 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
226 ///   ...
227 ///   vector.yield %3 : vector<32x16x64xf32>
228 /// }
229 /// ```
230 /// Would have an implicit map of:
231 /// `(d0, d1, d2) -> (d0, d2)`
232 static AffineMap calculateImplicitMap(Value yield, Value ret) {
233   auto srcType = yield.getType().cast<VectorType>();
234   auto dstType = ret.getType().cast<VectorType>();
235   SmallVector<AffineExpr> perm;
236   // Check which dimensions of the yield value are different than the dimensions
237   // of the result to know the distributed dimensions. Then associate each
238   // distributed dimension to an ID in order.
239   for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
240     if (srcType.getDimSize(i) != dstType.getDimSize(i))
241       perm.push_back(getAffineDimExpr(i, yield.getContext()));
242   }
243   auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
244   return map;
245 }
246 
247 namespace {
248 
/// Pattern adapter that lowers every WarpExecuteOnLane0Op via
/// rewriteWarpOpToScfFor, parameterized by the user-provided `options`
/// (allocation and synchronization callbacks).
struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  // Stored by reference: the options object must outlive the pattern.
  const WarpExecuteOnLane0LoweringOptions &options;
};
264 
/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  // Yield the written vector out of the warp op as an extra result with the
  // distributed type `targetType`. Note: this replaces/erases `warpOp`.
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType});
  // Clone the write after the new warp op, erase the original, and make the
  // clone consume the distributed vector (the last result of the new op).
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
  return newWriteOp;
}
286 
287 /// Distribute transfer_write ops based on the affine map returned by
288 /// `distributionMapFn`.
289 /// Example:
290 /// ```
291 /// %0 = vector.warp_execute_on_lane_0(%id){
292 ///   ...
293 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
294 ///   vector.yield
295 /// }
296 /// ```
297 /// To
298 /// ```
299 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
300 ///   ...
301 ///   vector.yield %v : vector<32xf32>
302 /// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
  /// are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Bail on 0-D writes: a rank-0 vector has no dimension to distribute
    // along. (tryExtractOp may still handle it via the fallback in
    // matchAndRewrite.)
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map: each distributed
    // dimension is divided by the warp size and must divide evenly.
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map: each distributed index
    // is offset by laneid * (distributed dim size).
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // newIndex = oldIndex + laneid * targetShape[vectorPos].
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vector of 1 element for now to not serialize large vector
    // store. This can later be controlled by user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps: splitting
    // them off would make no progress.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    // Yield the written vector out of the original warp op.
    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp: hoisting the
    // write out of the warp op would otherwise reorder it past them.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    // Apart from the written vector itself, all operands must come from
    // outside the region so they stay valid after the write is moved out.
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};
442 
443 /// Sink out elementwise op feeding into a warp op yield.
444 /// ```
445 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
446 ///   ...
447 ///   %3 = arith.addf %1, %2 : vector<32xf32>
448 ///   vector.yield %3 : vector<32xf32>
449 /// }
450 /// ```
451 /// To
452 /// ```
453 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
454 /// vector<1xf32>, vector<1xf32>) {
455 ///   ...
456 ///   %4 = arith.addf %2, %3 : vector<32xf32>
457 ///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
458 ///   vector<32xf32>
459 /// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
461 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
462   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
463   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
464                                 PatternRewriter &rewriter) const override {
465     OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
466       return OpTrait::hasElementwiseMappableTraits(op);
467     });
468     if (!yieldOperand)
469       return failure();
470     Operation *elementWise = yieldOperand->get().getDefiningOp();
471     unsigned operandIndex = yieldOperand->getOperandNumber();
472     Value distributedVal = warpOp.getResult(operandIndex);
473     SmallVector<Value> yieldValues;
474     SmallVector<Type> retTypes;
475     Location loc = warpOp.getLoc();
476     for (OpOperand &operand : elementWise->getOpOperands()) {
477       Type targetType;
478       if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
479         // If the result type is a vector, the operands must also be vectors.
480         auto operandType = operand.get().getType().cast<VectorType>();
481         targetType =
482             VectorType::get(vecType.getShape(), operandType.getElementType());
483       } else {
484         auto operandType = operand.get().getType();
485         assert(!operandType.isa<VectorType>() &&
486                "unexpected yield of vector from op with scalar result type");
487         targetType = operandType;
488       }
489       retTypes.push_back(targetType);
490       yieldValues.push_back(operand.get());
491     }
492     unsigned numResults = warpOp.getNumResults();
493     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
494         rewriter, warpOp, yieldValues, retTypes);
495     rewriter.setInsertionPointAfter(newWarpOp);
496     SmallVector<Value> newOperands(elementWise->getOperands().begin(),
497                                    elementWise->getOperands().end());
498     for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
499       newOperands[i] = newWarpOp.getResult(i + numResults);
500     }
501     OpBuilder::InsertionGuard g(rewriter);
502     rewriter.setInsertionPointAfter(newWarpOp);
503     Operation *newOp = cloneOpWithOperandsAndTypes(
504         rewriter, loc, elementWise, newOperands,
505         {newWarpOp.getResult(operandIndex).getType()});
506     newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
507     return success();
508   }
509 };
510 
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
529 struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
530   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
531   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
532                                 PatternRewriter &rewriter) const override {
533     OpOperand *operand = getWarpResult(
534         warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
535     if (!operand)
536       return failure();
537     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
538     unsigned operandIndex = operand->getOperandNumber();
539     Value distributedVal = warpOp.getResult(operandIndex);
540 
541     SmallVector<Value, 4> indices(read.getIndices().begin(),
542                                   read.getIndices().end());
543     AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
544     AffineMap indexMap = map.compose(read.getPermutationMap());
545     OpBuilder::InsertionGuard g(rewriter);
546     rewriter.setInsertionPointAfter(warpOp);
547     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
548       AffineExpr d0, d1;
549       bindDims(read.getContext(), d0, d1);
550       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
551       if (!indexExpr)
552         continue;
553       unsigned indexPos = indexExpr.getPosition();
554       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
555       int64_t scale =
556           distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
557       indices[indexPos] =
558           makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
559                                   {indices[indexPos], warpOp.getLaneid()});
560     }
561     Value newRead = rewriter.create<vector::TransferReadOp>(
562         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
563         read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
564         read.getInBoundsAttr());
565     distributedVal.replaceAllUsesWith(newRead);
566     return success();
567   }
568 };
569 
570 /// Remove any result that has no use along with the matching yieldOp operand.
571 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
572 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
573   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
574   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
575                                 PatternRewriter &rewriter) const override {
576     SmallVector<Type> resultTypes;
577     SmallVector<Value> yieldValues;
578     auto yield = cast<vector::YieldOp>(
579         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
580     for (OpResult result : warpOp.getResults()) {
581       if (result.use_empty())
582         continue;
583       resultTypes.push_back(result.getType());
584       yieldValues.push_back(yield.getOperand(result.getResultNumber()));
585     }
586     if (yield.getNumOperands() == yieldValues.size())
587       return failure();
588     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
589         rewriter, warpOp, yieldValues, resultTypes);
590     unsigned resultIndex = 0;
591     for (OpResult result : warpOp.getResults()) {
592       if (result.use_empty())
593         continue;
594       result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
595     }
596     rewriter.eraseOp(warpOp);
597     return success();
598   }
599 };
600 
601 // If an operand is directly yielded out of the region we can forward it
602 // directly and it doesn't need to go through the region.
603 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
604   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
605   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
606                                 PatternRewriter &rewriter) const override {
607     SmallVector<Type> resultTypes;
608     SmallVector<Value> yieldValues;
609     auto yield = cast<vector::YieldOp>(
610         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
611     Value valForwarded;
612     unsigned resultIndex;
613     for (OpOperand &operand : yield->getOpOperands()) {
614       Value result = warpOp.getResult(operand.getOperandNumber());
615       if (result.use_empty())
616         continue;
617 
618       // Assume all the values coming from above are uniform.
619       if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
620         if (result.getType() != operand.get().getType())
621           continue;
622         valForwarded = operand.get();
623         resultIndex = operand.getOperandNumber();
624         break;
625       }
626       auto arg = operand.get().dyn_cast<BlockArgument>();
627       if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
628         continue;
629       Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
630       if (result.getType() != warpOperand.getType())
631         continue;
632       valForwarded = warpOperand;
633       resultIndex = operand.getOperandNumber();
634       break;
635     }
636     if (!valForwarded)
637       return failure();
638     warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
639     return success();
640   }
641 };
642 
643 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
644   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
645   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
646                                 PatternRewriter &rewriter) const override {
647     OpOperand *operand = getWarpResult(
648         warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
649     if (!operand)
650       return failure();
651     unsigned int operandNumber = operand->getOperandNumber();
652     auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
653     Location loc = broadcastOp.getLoc();
654     auto destVecType =
655         warpOp->getResultTypes()[operandNumber].cast<VectorType>();
656     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
657         rewriter, warpOp, {broadcastOp.getSource()},
658         {broadcastOp.getSource().getType()});
659     rewriter.setInsertionPointAfter(newWarpOp);
660     Value broadcasted = rewriter.create<vector::BroadcastOp>(
661         loc, destVecType, newWarpOp->getResults().back());
662     newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
663     return success();
664   }
665 };
666 
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Warp results fed by the forOp; these become the init operands of the
    // new outer scf.for created below.
    SmallVector<Value> newOperands;
    // Indices of the warpOp results that came from the forOp, used at the
    // end to rewire uses to the new loop's results.
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      // Redirect the warp yield to the loop's init value so the warpOp keeps
      // producing a value of the right type once the loop is moved out.
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
    // inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    // The inner warp op takes the distributed iter_args as its args and
    // exposes them inside its body with the original forOp result types.
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    // Mapping for merging the old loop body into the inner warp region: the
    // old induction variable maps to the new one, the old iter_args map to
    // the inner warp's block arguments.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    // Save the operands of the old scf.yield before erasing it; they become
    // the operands of the new vector.yield terminator.
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operands 0-2 of scf.for are lowerBound/upperBound/step (see the
      // create call above), hence the +3 to reach the iter_args operands.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};
759 
760 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
761 /// The vector is reduced in parallel. Currently limited to vector size matching
762 /// the warpOp size. E.g.:
763 /// ```
764 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
765 ///   %0 = "some_def"() : () -> (vector<32xf32>)
766 ///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
767 ///   vector_ext.yield %1 : f32
768 /// }
769 /// ```
770 /// is lowered to:
771 /// ```
772 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
773 ///   %1 = "some_def"() : () -> (vector<32xf32>)
774 ///   vector_ext.yield %1 : vector<32xf32>
775 /// }
776 /// %a = vector.extract %0[0] : vector<1xf32>
777 /// %r = ("warp.reduction %a")
778 /// ```
779 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
780   WarpOpReduction(MLIRContext *context,
781                   DistributedReductionFn distributedReductionFn,
782                   PatternBenefit benefit = 1)
783       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
784         distributedReductionFn(distributedReductionFn) {}
785 
786   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
787                                 PatternRewriter &rewriter) const override {
788     OpOperand *yieldOperand = getWarpResult(
789         warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
790     if (!yieldOperand)
791       return failure();
792 
793     auto reductionOp =
794         cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
795     auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
796     // Only rank 1 vectors supported.
797     if (vectorType.getRank() != 1)
798       return rewriter.notifyMatchFailure(
799           warpOp, "Only rank 1 reductions can be distributed.");
800     // Only warp_size-sized vectors supported.
801     if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
802       return rewriter.notifyMatchFailure(
803           warpOp, "Reduction vector dimension must match was size.");
804     // Only f32 and i32 element types are supported.
805     if (!reductionOp.getType().isF32() &&
806         !reductionOp.getType().isSignlessInteger(32))
807       return rewriter.notifyMatchFailure(
808           warpOp,
809           "Reduction distribution currently only supports 32bits types.");
810 
811     Location yieldLoc = yieldOperand->getOwner()->getLoc();
812 
813     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
814     unsigned operandIndex = yieldOperand->getOperandNumber();
815     SmallVector<Value> yieldValues = {reductionOp.getVector()};
816     SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
817     unsigned numResults = warpOp.getNumResults();
818     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
819         rewriter, warpOp, yieldValues, retTypes);
820     rewriter.setInsertionPointAfter(newWarpOp);
821 
822     // Every lane has one scalar value. These should be reduced.
823     Value laneValVec = newWarpOp.getResult(numResults);
824     Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
825     laneVal =
826         distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
827                                reductionOp.getKind(), newWarpOp.getWarpSize());
828     newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
829     return success();
830   }
831 
832 private:
833   DistributedReductionFn distributedReductionFn;
834 };
835 
836 } // namespace
837 
/// Populate `patterns` with the pattern that rewrites WarpExecuteOnLane0Op
/// into an scf.if executed by lane 0 only, using the allocation/
/// synchronization hooks provided in `options`.
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}
843 
/// Populate `patterns` with the pattern distributing vector.transfer_write
/// ops out of WarpExecuteOnLane0Op regions; `distributionMapFn` is forwarded
/// to WarpOpTransferWrite to decide how the written vector is distributed.
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}
848 
/// Populate `patterns` with the patterns that propagate vector distribution
/// through a WarpExecuteOnLane0Op: sinking elementwise ops, transfer_read,
/// broadcast and trailing scf.for ops, forwarding operands, and removing
/// dead warp results.
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}
855 
856 void mlir::vector::populateDistributeReduction(
857     RewritePatternSet &patterns,
858     DistributedReductionFn distributedReductionFn) {
859   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
860 }
861 
862 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
863   Block *body = warpOp.getBody();
864 
865   // Keep track of the ops we want to hoist.
866   llvm::SmallSetVector<Operation *, 8> opsToMove;
867 
868   // Helper to check if a value is or will be defined outside of the region.
869   auto isDefinedOutsideOfBody = [&](Value value) {
870     auto *definingOp = value.getDefiningOp();
871     return (definingOp && opsToMove.count(definingOp)) ||
872            warpOp.isDefinedOutsideOfRegion(value);
873   };
874 
875   // Do not use walk here, as we do not want to go into nested regions and hoist
876   // operations from there.
877   for (auto &op : body->without_terminator()) {
878     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
879       return result.getType().isa<VectorType>();
880     });
881     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
882       opsToMove.insert(&op);
883   }
884 
885   // Move all the ops marked as uniform outside of the region.
886   for (Operation *op : opsToMove)
887     op->moveBefore(warpOp);
888 }
889