//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
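///
/// For example, with the InnerReduction lowering option (an illustrative
/// sketch; the exact textual IR syntax may differ across MLIR versions), a
/// reduction over the outer dimension
///
///   %0 = vector.multi_reduction #vector.kind<add>, %src [0]
///     : vector<4x8xf32> to vector<8xf32>
///
/// is rewritten so that the reduced dimension becomes the innermost one:
///
///   %t = vector.transpose %src, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
///   %0 = vector.multi_reduction #vector.kind<add>, %t [1]
///     : vector<8x4xf32> to vector<8xf32>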
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto src = multiReductionOp.source();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    auto reductionDimsRange =
        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add a transpose only if the parallel dims are not already in the
    // position required by the lowering option (leading for inner-dim
    // reductions, trailing for outer-dim reductions) and there is at least
    // one parallel dim.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims !=
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
        multiReductionOp, transposeOp.result(), reductionMask,
        multiReductionOp.kind());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, provided all
/// reduction dimensions are either the innermost or the outermost ones.
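///
/// For example, with inner reductions (an illustrative sketch; the exact
/// textual IR syntax may differ across MLIR versions):
///
///   %0 = vector.multi_reduction #vector.kind<add>, %src [2, 3]
///     : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten by collapsing the parallel and reduction dims, reducing in
/// 2-D, and expanding the result back to the parallel shape:
///
///   %c = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %r = vector.multi_reduction #vector.kind<add>, %c [1]
///     : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>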
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
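    // E.g. parallelShapes = {2, 3} and reductionShapes = {4, 5} flatten to
    // flattenedParallelDim = 6 and flattenedReductionDim = 20.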
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly
    // [reductionDims.size() .. reductionDims.size() + parallelDims.size()).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());

    // 5. Create the flattened form of vector.multi_reduction with the
    // inner-/outer-most dim as the reduction dim.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.dest());
      return success();
    }

    // 7. Create a shape cast to expand the flattened result back to the
    // original n-D parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.dest());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls a 2-D vector.multi_reduction whose reduction dimension is the
/// outermost one into a chain of elementwise combines over the rows.
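///
/// For example (an illustrative sketch; the exact textual IR syntax may
/// differ across MLIR versions):
///
///   %0 = vector.multi_reduction #vector.kind<add>, %src [0]
///     : vector<3x4xf32> to vector<4xf32>
///
/// is unrolled along the reduced (outer) dimension into elementwise combines
/// of the rows:
///
///   %r0 = vector.extract %src[0] : vector<3x4xf32>
///   %r1 = vector.extract %src[1] : vector<3x4xf32>
///   %a0 = arith.addf %r1, %r0 : vector<4xf32>
///   %r2 = vector.extract %src[2] : vector<3x4xf32>
///   %0  = arith.addf %r2, %a0 : vector<4xf32>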
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result =
        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
            .getResult();
    for (int64_t i = 1; i < srcShape[0]; i++) {
      auto operand =
          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with the innermost reduction
/// dimension into a sequence of vector.reduction ops.
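///
/// For example (an illustrative sketch; the exact textual IR syntax may
/// differ across MLIR versions, and %c0/%c1 stand for index constants):
///
///   %0 = vector.multi_reduction #vector.kind<add>, %src [1]
///     : vector<2x4xf32> to vector<2xf32>
///
/// becomes one vector.reduction per row, inserted into a zero-initialized
/// result vector:
///
///   %acc = arith.constant dense<0.0> : vector<2xf32>
///   %v0 = vector.extract %src[0] : vector<2x4xf32>
///   %s0 = vector.reduction "add", %v0 : vector<4xf32> into f32
///   %r0 = vector.insertelement %s0, %acc[%c0 : index] : vector<2xf32>
///   %v1 = vector.extract %src[1] : vector<2x4xf32>
///   %s1 = vector.reduction "add", %v1 : vector<4xf32> into f32
///   %0  = vector.insertelement %s1, %r0[%c1 : index] : vector<2xf32>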
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
      auto reducedValue =
          rewriter.create<vector::ReductionOp>(loc, multiReductionOp.kind(), v);
      result = rewriter.create<vector::InsertElementOp>(
          loc, reducedValue, result,
          rewriter.create<arith::ConstantIndexOp>(loc, i));
    }
    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension
/// into a 2-D form with one parallel and one reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
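///
/// For example (an illustrative sketch; the exact textual IR syntax may
/// differ across MLIR versions):
///
///   %0 = vector.multi_reduction #vector.kind<add>, %src [0]
///     : vector<8xf32> to f32
///
/// becomes
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %r = vector.multi_reduction #vector.kind<add>, %c [1]
///     : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : vector<1xf32>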
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                      srcVectorType.getElementType());
    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel dim in front, we
    // need a {false, true} mask.
    SmallVector<bool, 2> mask{false, true};

    // vector.extract(vector.multi_reduction(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());
    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

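// Example usage of the populate function below (a sketch, assuming the greedy
// pattern rewrite driver from mlir/Transforms/GreedyPatternRewriteDriver.h;
// `funcOp` is a placeholder for the op whose body should be rewritten):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorMultiReductionLoweringPatterns(
//       patterns, vector::VectorMultiReductionLowering::InnerReduction);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));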
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
}