1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/SCF/SCF.h"
13 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
14 #include "mlir/Transforms/SideEffectUtils.h"
15 
16 using namespace mlir;
17 using namespace mlir::vector;
18 
19 static LogicalResult
20 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
21                       const WarpExecuteOnLane0LoweringOptions &options) {
22   assert(warpOp.getBodyRegion().hasOneBlock() &&
23          "expected WarpOp with single block");
24   Block *warpOpBody = &warpOp.getBodyRegion().front();
25   Location loc = warpOp.getLoc();
26 
27   // Passed all checks. Start rewriting.
28   OpBuilder::InsertionGuard g(rewriter);
29   rewriter.setInsertionPoint(warpOp);
30 
31   // Create scf.if op.
32   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
33   Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
34                                                  warpOp.getLaneid(), c0);
35   auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
36                                          /*withElseRegion=*/false);
37   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
38 
39   // Store vectors that are defined outside of warpOp into the scratch pad
40   // buffer.
41   SmallVector<Value> bbArgReplacements;
42   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
43     Value val = it.value();
44     Value bbArg = warpOpBody->getArgument(it.index());
45 
46     rewriter.setInsertionPoint(ifOp);
47     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
48                                             bbArg.getType());
49 
50     // Store arg vector into buffer.
51     rewriter.setInsertionPoint(ifOp);
52     auto vectorType = val.getType().cast<VectorType>();
53     int64_t storeSize = vectorType.getShape()[0];
54     Value storeOffset = rewriter.create<arith::MulIOp>(
55         loc, warpOp.getLaneid(),
56         rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
57     rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
58 
59     // Load bbArg vector from buffer.
60     rewriter.setInsertionPointToStart(ifOp.thenBlock());
61     auto bbArgType = bbArg.getType().cast<VectorType>();
62     Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
63     bbArgReplacements.push_back(loadOp);
64   }
65 
66   // Insert sync after all the stores and before all the loads.
67   if (!warpOp.getArgs().empty()) {
68     rewriter.setInsertionPoint(ifOp);
69     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
70   }
71 
72   // Move body of warpOp to ifOp.
73   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
74 
75   // Rewrite terminator and compute replacements of WarpOp results.
76   SmallVector<Value> replacements;
77   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
78   Location yieldLoc = yieldOp.getLoc();
79   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
80     Value val = it.value();
81     Type resultType = warpOp->getResultTypes()[it.index()];
82     rewriter.setInsertionPoint(ifOp);
83     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
84                                             val.getType());
85 
86     // Store yielded value into buffer.
87     rewriter.setInsertionPoint(yieldOp);
88     if (val.getType().isa<VectorType>())
89       rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
90     else
91       rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
92 
93     // Load value from buffer (after warpOp).
94     rewriter.setInsertionPointAfter(ifOp);
95     if (resultType == val.getType()) {
96       // Result type and yielded value type are the same. This is a broadcast.
97       // E.g.:
98       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
99       //   vector.yield %cst : f32
100       // }
101       // Both types are f32. The constant %cst is broadcasted to all lanes.
102       // This is described in more detail in the documentation of the op.
103       Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
104       replacements.push_back(loadOp);
105     } else {
106       auto loadedVectorType = resultType.cast<VectorType>();
107       int64_t loadSize = loadedVectorType.getShape()[0];
108 
109       // loadOffset = laneid * loadSize
110       Value loadOffset = rewriter.create<arith::MulIOp>(
111           loc, warpOp.getLaneid(),
112           rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
113       Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
114                                                      buffer, loadOffset);
115       replacements.push_back(loadOp);
116     }
117   }
118 
119   // Insert sync after all the stores and before all the loads.
120   if (!yieldOp.operands().empty()) {
121     rewriter.setInsertionPointAfter(ifOp);
122     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
123   }
124 
125   // Delete terminator and add empty scf.yield.
126   rewriter.eraseOp(yieldOp);
127   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
128   rewriter.create<scf::YieldOp>(yieldLoc);
129 
130   // Compute replacements for WarpOp results.
131   rewriter.replaceOp(warpOp, replacements);
132 
133   return success();
134 }
135 
136 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
137 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
138     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
139     ValueRange newYieldedValues, TypeRange newReturnTypes) {
140   // Create a new op before the existing one, with the extra operands.
141   OpBuilder::InsertionGuard g(rewriter);
142   rewriter.setInsertionPoint(warpOp);
143   auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
144       warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
145       warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
146 
147   Region &opBody = warpOp.getBodyRegion();
148   Region &newOpBody = newWarpOp.getBodyRegion();
149   rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
150   auto yield =
151       cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
152 
153   rewriter.updateRootInPlace(
154       yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
155   return newWarpOp;
156 }
157 
158 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
159 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
160     RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
161     ValueRange newYieldedValues, TypeRange newReturnTypes) {
162   SmallVector<Type> types(warpOp.getResultTypes().begin(),
163                           warpOp.getResultTypes().end());
164   types.append(newReturnTypes.begin(), newReturnTypes.end());
165   auto yield = cast<vector::YieldOp>(
166       warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
167   SmallVector<Value> yieldValues(yield.getOperands().begin(),
168                                  yield.getOperands().end());
169   yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
170   WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
171       rewriter, warpOp, yieldValues, types);
172   rewriter.replaceOp(warpOp,
173                      newWarpOp.getResults().take_front(warpOp.getNumResults()));
174   return newWarpOp;
175 }
176 
177 /// Helper to know if an op can be hoisted out of the region.
178 static bool canBeHoisted(Operation *op,
179                          function_ref<bool(Value)> definedOutside) {
180   return llvm::all_of(op->getOperands(), definedOutside) &&
181          isSideEffectFree(op) && op->getNumRegions() == 0;
182 }
183 
184 namespace {
185 
186 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
187   WarpOpToScfForPattern(MLIRContext *context,
188                         const WarpExecuteOnLane0LoweringOptions &options,
189                         PatternBenefit benefit = 1)
190       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
191         options(options) {}
192 
193   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
194                                 PatternRewriter &rewriter) const override {
195     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
196   }
197 
198 private:
199   const WarpExecuteOnLane0LoweringOptions &options;
200 };
201 
202 /// Distribute transfer_write ops based on the affine map returned by
203 /// `distributionMapFn`.
204 /// Example:
205 /// ```
206 /// %0 = vector.warp_execute_on_lane_0(%id){
207 ///   ...
208 ///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
209 ///   vector.yield
210 /// }
211 /// ```
212 /// To
213 /// ```
214 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
215 ///   ...
216 ///   vector.yield %v : vector<32xf32>
217 /// }
218 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
219 struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
220   WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
221                       PatternBenefit b = 1)
222       : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
223         distributionMapFn(fn) {}
224 
225   /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
226   /// are multiples of the distribution ratio are supported at the moment.
227   LogicalResult tryDistributeOp(RewriterBase &rewriter,
228                                 vector::TransferWriteOp writeOp,
229                                 WarpExecuteOnLane0Op warpOp) const {
230     AffineMap map = distributionMapFn(writeOp);
231     SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
232                                      writeOp.getVectorType().getShape().end());
233     assert(map.getNumResults() == 1 &&
234            "multi-dim distribution not implemented yet");
235     for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
236       unsigned position = map.getDimPosition(i);
237       if (targetShape[position] % warpOp.getWarpSize() != 0)
238         return failure();
239       targetShape[position] = targetShape[position] / warpOp.getWarpSize();
240     }
241     VectorType targetType =
242         VectorType::get(targetShape, writeOp.getVectorType().getElementType());
243 
244     SmallVector<Value> yieldValues = {writeOp.getVector()};
245     SmallVector<Type> retTypes = {targetType};
246     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
247         rewriter, warpOp, yieldValues, retTypes);
248     rewriter.setInsertionPointAfter(newWarpOp);
249 
250     // Move op outside of region: Insert clone at the insertion point and delete
251     // the old op.
252     auto newWriteOp =
253         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
254     rewriter.eraseOp(writeOp);
255 
256     rewriter.setInsertionPoint(newWriteOp);
257     AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
258     Location loc = newWriteOp.getLoc();
259     SmallVector<Value> indices(newWriteOp.getIndices().begin(),
260                                newWriteOp.getIndices().end());
261     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
262       AffineExpr d0, d1;
263       bindDims(newWarpOp.getContext(), d0, d1);
264       auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
265       if (!indexExpr)
266         continue;
267       unsigned indexPos = indexExpr.getPosition();
268       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
269       auto scale =
270           getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
271       indices[indexPos] =
272           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
273                                   {indices[indexPos], newWarpOp.getLaneid()});
274     }
275     newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
276     newWriteOp.getIndicesMutable().assign(indices);
277 
278     return success();
279   }
280 
281   /// Extract TransferWriteOps of vector<1x> into a separate warp op.
282   LogicalResult tryExtractOp(RewriterBase &rewriter,
283                              vector::TransferWriteOp writeOp,
284                              WarpExecuteOnLane0Op warpOp) const {
285     Location loc = writeOp.getLoc();
286     VectorType vecType = writeOp.getVectorType();
287 
288     // Only vector<1x> is supported at the moment.
289     if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
290       return failure();
291 
292     // Do not process warp ops that contain only TransferWriteOps.
293     if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
294           return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
295         }))
296       return failure();
297 
298     SmallVector<Value> yieldValues = {writeOp.getVector()};
299     SmallVector<Type> retTypes = {vecType};
300     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
301         rewriter, warpOp, yieldValues, retTypes);
302     rewriter.setInsertionPointAfter(newWarpOp);
303 
304     // Create a second warp op that contains only writeOp.
305     auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
306         loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
307     Block &body = secondWarpOp.getBodyRegion().front();
308     rewriter.setInsertionPointToStart(&body);
309     auto newWriteOp =
310         cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
311     newWriteOp.getVectorMutable().assign(
312         newWarpOp.getResult(newWarpOp.getNumResults() - 1));
313     rewriter.eraseOp(writeOp);
314     rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
315     return success();
316   }
317 
318   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
319                                 PatternRewriter &rewriter) const override {
320     // Ops with mask not supported yet.
321     if (writeOp.getMask())
322       return failure();
323 
324     auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
325     if (!warpOp)
326       return failure();
327 
328     // There must be no op with a side effect after writeOp.
329     Operation *nextOp = writeOp.getOperation();
330     while ((nextOp = nextOp->getNextNode()))
331       if (!isSideEffectFree(nextOp))
332         return failure();
333 
334     if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
335           return writeOp.getVector() == value ||
336                  warpOp.isDefinedOutsideOfRegion(value);
337         }))
338       return failure();
339 
340     if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
341       return success();
342 
343     if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
344       return success();
345 
346     return failure();
347   }
348 
349 private:
350   DistributionMapFn distributionMapFn;
351 };
352 
353 } // namespace
354 
355 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
356     RewritePatternSet &patterns,
357     const WarpExecuteOnLane0LoweringOptions &options) {
358   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
359 }
360 
361 void mlir::vector::populateDistributeTransferWriteOpPatterns(
362     RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
363   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
364 }
365 
366 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
367   Block *body = warpOp.getBody();
368 
369   // Keep track of the ops we want to hoist.
370   llvm::SmallSetVector<Operation *, 8> opsToMove;
371 
372   // Helper to check if a value is or will be defined outside of the region.
373   auto isDefinedOutsideOfBody = [&](Value value) {
374     auto *definingOp = value.getDefiningOp();
375     return (definingOp && opsToMove.count(definingOp)) ||
376            warpOp.isDefinedOutsideOfRegion(value);
377   };
378 
379   // Do not use walk here, as we do not want to go into nested regions and hoist
380   // operations from there.
381   for (auto &op : body->without_terminator()) {
382     bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
383       return result.getType().isa<VectorType>();
384     });
385     if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
386       opsToMove.insert(&op);
387   }
388 
389   // Move all the ops marked as uniform outside of the region.
390   for (Operation *op : opsToMove)
391     op->moveBefore(warpOp);
392 }
393