1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/SCF/SCF.h"
12 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
13 
14 using namespace mlir;
15 using namespace mlir::vector;
16 
17 static LogicalResult
18 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
19                       const WarpExecuteOnLane0LoweringOptions &options) {
20   assert(warpOp.getBodyRegion().hasOneBlock() &&
21          "expected WarpOp with single block");
22   Block *warpOpBody = &warpOp.getBodyRegion().front();
23   Location loc = warpOp.getLoc();
24 
25   // Passed all checks. Start rewriting.
26   OpBuilder::InsertionGuard g(rewriter);
27   rewriter.setInsertionPoint(warpOp);
28 
29   // Create scf.if op.
30   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
31   Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
32                                                  warpOp.getLaneid(), c0);
33   auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
34                                          /*withElseRegion=*/false);
35   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
36 
37   // Store vectors that are defined outside of warpOp into the scratch pad
38   // buffer.
39   SmallVector<Value> bbArgReplacements;
40   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
41     Value val = it.value();
42     Value bbArg = warpOpBody->getArgument(it.index());
43 
44     rewriter.setInsertionPoint(ifOp);
45     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
46                                             bbArg.getType());
47 
48     // Store arg vector into buffer.
49     rewriter.setInsertionPoint(ifOp);
50     auto vectorType = val.getType().cast<VectorType>();
51     int64_t storeSize = vectorType.getShape()[0];
52     Value storeOffset = rewriter.create<arith::MulIOp>(
53         loc, warpOp.getLaneid(),
54         rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
55     rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
56 
57     // Load bbArg vector from buffer.
58     rewriter.setInsertionPointToStart(ifOp.thenBlock());
59     auto bbArgType = bbArg.getType().cast<VectorType>();
60     Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
61     bbArgReplacements.push_back(loadOp);
62   }
63 
64   // Insert sync after all the stores and before all the loads.
65   if (!warpOp.getArgs().empty()) {
66     rewriter.setInsertionPoint(ifOp);
67     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
68   }
69 
70   // Move body of warpOp to ifOp.
71   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
72 
73   // Rewrite terminator and compute replacements of WarpOp results.
74   SmallVector<Value> replacements;
75   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
76   Location yieldLoc = yieldOp.getLoc();
77   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
78     Value val = it.value();
79     Type resultType = warpOp->getResultTypes()[it.index()];
80     rewriter.setInsertionPoint(ifOp);
81     Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
82                                             val.getType());
83 
84     // Store yielded value into buffer.
85     rewriter.setInsertionPoint(yieldOp);
86     if (val.getType().isa<VectorType>())
87       rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
88     else
89       rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
90 
91     // Load value from buffer (after warpOp).
92     rewriter.setInsertionPointAfter(ifOp);
93     if (resultType == val.getType()) {
94       // Result type and yielded value type are the same. This is a broadcast.
95       // E.g.:
96       // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
97       //   vector_ext.yield %cst : f32
98       // }
99       // Both types are f32. The constant %cst is broadcasted to all lanes.
100       // This is described in more detail in the documentation of the op.
101       Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
102       replacements.push_back(loadOp);
103     } else {
104       auto loadedVectorType = resultType.cast<VectorType>();
105       int64_t loadSize = loadedVectorType.getShape()[0];
106 
107       // loadOffset = laneid * loadSize
108       Value loadOffset = rewriter.create<arith::MulIOp>(
109           loc, warpOp.getLaneid(),
110           rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
111       Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
112                                                      buffer, loadOffset);
113       replacements.push_back(loadOp);
114     }
115   }
116 
117   // Insert sync after all the stores and before all the loads.
118   if (!yieldOp.operands().empty()) {
119     rewriter.setInsertionPointAfter(ifOp);
120     options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
121   }
122 
123   // Delete terminator and add empty scf.yield.
124   rewriter.eraseOp(yieldOp);
125   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
126   rewriter.create<scf::YieldOp>(yieldLoc);
127 
128   // Compute replacements for WarpOp results.
129   rewriter.replaceOp(warpOp, replacements);
130 
131   return success();
132 }
133 
134 namespace {
135 
136 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
137   WarpOpToScfForPattern(MLIRContext *context,
138                         const WarpExecuteOnLane0LoweringOptions &options,
139                         PatternBenefit benefit = 1)
140       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
141         options(options) {}
142 
143   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
144                                 PatternRewriter &rewriter) const override {
145     return rewriteWarpOpToScfFor(rewriter, warpOp, options);
146   }
147 
148 private:
149   const WarpExecuteOnLane0LoweringOptions &options;
150 };
151 
152 } // namespace
153 
154 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
155     RewritePatternSet &patterns,
156     const WarpExecuteOnLane0LoweringOptions &options) {
157   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
158 }
159