//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"
#include "llvm/ADT/SetVector.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;

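/// Lower a WarpExecuteOnLane0Op to an scf.if guarded by `laneid == 0`,
/// communicating values in and out of the single-lane region through the
/// scratch-pad buffers provided by `options.warpAllocationFn`. Rough shape of
/// the output (an illustrative sketch only; the actual buffer and
/// synchronization ops depend on the hooks supplied in `options`):
/// ```
/// %buffer = ...           // created by options.warpAllocationFn
/// %cnd = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %cnd {
///   // Former body of the warp op, executed by lane 0 only. Yielded values
///   // are stored into %buffer.
/// }
/// // Synchronization inserted via options.warpSyncronizationFn.
/// %r = vector.load %buffer[%offset] : ... // distributed result, all lanes
/// ```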
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);
  // Replace the WarpOp results with the values loaded from the buffers.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
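/// The body of `warpOp` is moved over to the new op and its terminator is
/// updated in place to yield `newYieldedValues` (which must correspond
/// one-to-one to `newReturnTypes`, modulo distribution). The caller is
/// responsible for replacing uses of `warpOp`'s results.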
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` is populated with the position of each new output.
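/// For example (illustrative): appending one new `vector<32xf32>` value to a
/// warp op that already yields two values creates a warp op with three
/// results and sets `indices` to {2}; if the value is already yielded, the
/// index of the existing result is recorded instead and no result is added.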
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value is already yielded out of the region, reuse the
      // existing output instead of creating a new one.
      for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand.value() == std::get<0>(newRet)) {
          indices.push_back(yieldOperand.index());
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

235 
236 /// Currently the distribution map is implicit based on the vector shape. In the
237 /// future it will be part of the op.
238 /// Example:
239 /// ```
240 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
241 ///   ...
242 ///   vector.yield %3 : vector<32x16x64xf32>
243 /// }
244 /// ```
245 /// Would have an implicit map of:
246 /// `(d0, d1, d2) -> (d0, d2)`
247 static AffineMap calculateImplicitMap(Value yield, Value ret) {
248   auto srcType = yield.getType().cast<VectorType>();
249   auto dstType = ret.getType().cast<VectorType>();
250   SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value differ from the dimensions of
  // the result to determine the distributed dimensions. Then associate each
  // distributed dimension with a dimension expression, in order.
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType}, newRetIndices);
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  return newWriteOp;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. Bail on 0-D writes; distributing a 0-D vector is not supported yet.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map.
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of single-element vectors into a separate warp
  /// op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of a single element for now so that we do not
    // serialize large vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
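/// For example (illustrative): if `%r:2 = vector.warp_execute_on_lane_0(...)`
/// yields two values but only `%r#0` has uses, the warp op is rebuilt with a
/// single result and the now-unused yield operand is dropped.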
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};


// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
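// E.g. (illustrative), a uniform value defined above the region and yielded
// back out is forwarded without going through the region:
// ```
// %0 = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
//   vector.yield %cst : f32
// }
// ```
// Uses of %0 are replaced directly with %cst.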
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

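/// Sink out a vector.broadcast feeding into a warp op yield. The broadcast
/// source (assumed uniform across lanes) is yielded out of the region instead
/// and re-broadcast to the distributed type after the warp op. Illustrative
/// sketch:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   %1 = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
///   vector.yield %s : f32
/// }
/// %0 = vector.broadcast %w : f32 to vector<1xf32>
/// ```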
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
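      // scf.for operands 0, 1 and 2 are the lower bound, upper bound and
      // step; the iter_args follow, hence the offset of 3.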
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are multiples of the warp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction vector dimension must be a multiple of the warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // First reduce on a single thread.
    Value perLaneReduction = rewriter.create<vector::ReductionOp>(
        reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
    // Then distribute across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};


} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}
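
// Typical driver-side usage of the populate functions above (an illustrative
// sketch, not part of this file; `funcOp`, `distributionMapFn` and the other
// callbacks are assumed to be supplied by the client):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateDistributeTransferWriteOpPatterns(patterns,
//                                                     distributionMapFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));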
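/// Hoist side-effect-free, non-vector ops whose operands are all defined
/// outside of the warp op (or are themselves hoisted) out of the region. Such
/// ops compute the same value on every lane, so executing them once outside
/// the region is safe.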
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}