//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 #include "mlir/IR/TypeUtilities.h"
19 
20 #define DEBUG_TYPE "vector-multi-reduction"
21 
22 using namespace mlir;
23 
24 /// This file implements the following transformations as composable atomic
25 /// patterns.
26 
27 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
28 /// by using vector.transpose
29 class InnerOuterDimReductionConversion
30     : public OpRewritePattern<vector::MultiDimReductionOp> {
31 public:
32   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
33 
InnerOuterDimReductionConversion(MLIRContext * context,vector::VectorMultiReductionLowering options)34   explicit InnerOuterDimReductionConversion(
35       MLIRContext *context, vector::VectorMultiReductionLowering options)
36       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
37         useInnerDimsForReduction(
38             options == vector::VectorMultiReductionLowering::InnerReduction) {}
39 
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const40   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
41                                 PatternRewriter &rewriter) const override {
42     auto src = multiReductionOp.getSource();
43     auto loc = multiReductionOp.getLoc();
44     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
45 
46     // Separate reduction and parallel dims
47     auto reductionDimsRange =
48         multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
49     auto reductionDims = llvm::to_vector<4>(llvm::map_range(
50         reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
51     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
52                                                   reductionDims.end());
53     int64_t reductionSize = reductionDims.size();
54     SmallVector<int64_t, 4> parallelDims;
55     for (int64_t i = 0; i < srcRank; ++i)
56       if (!reductionDimsSet.contains(i))
57         parallelDims.push_back(i);
58 
59     // Add transpose only if inner-most/outer-most dimensions are not parallel
60     // and there are parallel dims.
61     if (parallelDims.empty())
62       return failure();
63     if (useInnerDimsForReduction &&
64         (parallelDims ==
65          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
66       return failure();
67 
68     if (!useInnerDimsForReduction &&
69         (parallelDims !=
70          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
71       return failure();
72 
73     SmallVector<int64_t, 4> indices;
74     if (useInnerDimsForReduction) {
75       indices.append(parallelDims.begin(), parallelDims.end());
76       indices.append(reductionDims.begin(), reductionDims.end());
77     } else {
78       indices.append(reductionDims.begin(), reductionDims.end());
79       indices.append(parallelDims.begin(), parallelDims.end());
80     }
81     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
82     SmallVector<bool> reductionMask(srcRank, false);
83     for (int i = 0; i < reductionSize; ++i) {
84       if (useInnerDimsForReduction)
85         reductionMask[srcRank - i - 1] = true;
86       else
87         reductionMask[i] = true;
88     }
89     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
90         multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
91         reductionMask, multiReductionOp.getKind());
92     return success();
93   }
94 
95 private:
96   const bool useInnerDimsForReduction;
97 };
98 
99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
100 /// dimensions are either inner most or outer most.
101 class ReduceMultiDimReductionRank
102     : public OpRewritePattern<vector::MultiDimReductionOp> {
103 public:
104   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
105 
ReduceMultiDimReductionRank(MLIRContext * context,vector::VectorMultiReductionLowering options)106   explicit ReduceMultiDimReductionRank(
107       MLIRContext *context, vector::VectorMultiReductionLowering options)
108       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
109         useInnerDimsForReduction(
110             options == vector::VectorMultiReductionLowering::InnerReduction) {}
111 
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const112   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
113                                 PatternRewriter &rewriter) const override {
114     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
115     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
116     auto loc = multiReductionOp.getLoc();
117 
118     // If rank less than 2, nothing to do.
119     if (srcRank < 2)
120       return failure();
121 
122     // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
123     SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
124     if (srcRank == 2 && reductionMask.front() != reductionMask.back())
125       return failure();
126 
127     // 1. Separate reduction and parallel dims.
128     SmallVector<int64_t, 4> parallelDims, parallelShapes;
129     SmallVector<int64_t, 4> reductionDims, reductionShapes;
130     for (const auto &it : llvm::enumerate(reductionMask)) {
131       int64_t i = it.index();
132       bool isReduction = it.value();
133       if (isReduction) {
134         reductionDims.push_back(i);
135         reductionShapes.push_back(srcShape[i]);
136       } else {
137         parallelDims.push_back(i);
138         parallelShapes.push_back(srcShape[i]);
139       }
140     }
141 
142     // 2. Compute flattened parallel and reduction sizes.
143     int flattenedParallelDim = 0;
144     int flattenedReductionDim = 0;
145     if (!parallelShapes.empty()) {
146       flattenedParallelDim = 1;
147       for (auto d : parallelShapes)
148         flattenedParallelDim *= d;
149     }
150     if (!reductionShapes.empty()) {
151       flattenedReductionDim = 1;
152       for (auto d : reductionShapes)
153         flattenedReductionDim *= d;
154     }
155     // We must at least have some parallel or some reduction.
156     assert((flattenedParallelDim || flattenedReductionDim) &&
157            "expected at least one parallel or reduction dim");
158 
159     // 3. Fail if reduction/parallel dims are not contiguous.
160     // Check parallelDims are exactly [0 .. size).
161     int64_t counter = 0;
162     if (useInnerDimsForReduction &&
163         llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
164       return failure();
165     // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
166     counter = reductionDims.size();
167     if (!useInnerDimsForReduction &&
168         llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
169       return failure();
170 
171     // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
172     // a single parallel (resp. reduction) dim.
173     SmallVector<bool, 2> mask;
174     SmallVector<int64_t, 2> vectorShape;
175     if (flattenedParallelDim) {
176       mask.push_back(false);
177       vectorShape.push_back(flattenedParallelDim);
178     }
179     if (flattenedReductionDim) {
180       mask.push_back(true);
181       vectorShape.push_back(flattenedReductionDim);
182     }
183     if (!useInnerDimsForReduction && vectorShape.size() == 2) {
184       std::swap(mask.front(), mask.back());
185       std::swap(vectorShape.front(), vectorShape.back());
186     }
187     auto castedType = VectorType::get(
188         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
189     Value cast = rewriter.create<vector::ShapeCastOp>(
190         loc, castedType, multiReductionOp.getSource());
191     Value acc = multiReductionOp.getAcc();
192     if (flattenedParallelDim) {
193       auto accType = VectorType::get(
194           {flattenedParallelDim},
195           multiReductionOp.getSourceVectorType().getElementType());
196       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
197     }
198     // 5. Creates the flattened form of vector.multi_reduction with inner/outer
199     // most dim as reduction.
200     auto newOp = rewriter.create<vector::MultiDimReductionOp>(
201         loc, cast, acc, mask, multiReductionOp.getKind());
202 
203     // 6. If there are no parallel shapes, the result is a scalar.
204     // TODO: support 0-d vectors when available.
205     if (parallelShapes.empty()) {
206       rewriter.replaceOp(multiReductionOp, newOp.getDest());
207       return success();
208     }
209 
210     // 7. Creates shape cast for the output n-D -> 2-D
211     VectorType outputCastedType = VectorType::get(
212         parallelShapes,
213         multiReductionOp.getSourceVectorType().getElementType());
214     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
215         multiReductionOp, outputCastedType, newOp.getDest());
216     return success();
217   }
218 
219 private:
220   const bool useInnerDimsForReduction;
221 };
222 
223 /// Unrolls vector.multi_reduction with outermost reductions
224 /// and combines results
225 struct TwoDimMultiReductionToElementWise
226     : public OpRewritePattern<vector::MultiDimReductionOp> {
227   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
228 
matchAndRewriteTwoDimMultiReductionToElementWise229   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
230                                 PatternRewriter &rewriter) const override {
231     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
232     // Rank-2 ["parallel", "reduce"] or bail.
233     if (srcRank != 2)
234       return failure();
235 
236     if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
237       return failure();
238 
239     auto loc = multiReductionOp.getLoc();
240     ArrayRef<int64_t> srcShape =
241         multiReductionOp.getSourceVectorType().getShape();
242 
243     Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
244     if (!elementType.isIntOrIndexOrFloat())
245       return failure();
246 
247     Value result = multiReductionOp.getAcc();
248     for (int64_t i = 0; i < srcShape[0]; i++) {
249       auto operand = rewriter.create<vector::ExtractOp>(
250           loc, multiReductionOp.getSource(), i);
251       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
252                                   operand, result);
253     }
254 
255     rewriter.replaceOp(multiReductionOp, result);
256     return success();
257   }
258 };
259 
260 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
261 /// a sequence of vector.reduction ops.
262 struct TwoDimMultiReductionToReduction
263     : public OpRewritePattern<vector::MultiDimReductionOp> {
264   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
265 
matchAndRewriteTwoDimMultiReductionToReduction266   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
267                                 PatternRewriter &rewriter) const override {
268     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
269     if (srcRank != 2)
270       return failure();
271 
272     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
273       return failure();
274 
275     auto loc = multiReductionOp.getLoc();
276     Value result = rewriter.create<arith::ConstantOp>(
277         loc, multiReductionOp.getDestType(),
278         rewriter.getZeroAttr(multiReductionOp.getDestType()));
279     int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
280 
281     for (int i = 0; i < outerDim; ++i) {
282       auto v = rewriter.create<vector::ExtractOp>(
283           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
284       auto acc = rewriter.create<vector::ExtractOp>(
285           loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
286       auto reducedValue = rewriter.create<vector::ReductionOp>(
287           loc, multiReductionOp.getKind(), v, acc);
288       result = rewriter.create<vector::InsertElementOp>(
289           loc, reducedValue, result,
290           rewriter.create<arith::ConstantIndexOp>(loc, i));
291     }
292     rewriter.replaceOp(multiReductionOp, result);
293     return success();
294   }
295 };
296 
297 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
298 /// form with both a single parallel and reduction dimension.
299 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
300 /// The case with a single parallel dimension is a noop and folds away
301 /// separately.
302 struct OneDimMultiReductionToTwoDim
303     : public OpRewritePattern<vector::MultiDimReductionOp> {
304   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
305 
matchAndRewriteOneDimMultiReductionToTwoDim306   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
307                                 PatternRewriter &rewriter) const override {
308     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
309     // Rank-1 or bail.
310     if (srcRank != 1)
311       return failure();
312 
313     auto loc = multiReductionOp.getLoc();
314     auto srcVectorType = multiReductionOp.getSourceVectorType();
315     auto srcShape = srcVectorType.getShape();
316     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
317                                       srcVectorType.getElementType());
318     auto accType =
319         VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
320     assert(!multiReductionOp.getDestType().isa<VectorType>() &&
321            "multi_reduction with a single dimension expects a scalar result");
322 
323     // If the unique dim is reduced and we insert a parallel in front, we need a
324     // {false, true} mask.
325     SmallVector<bool, 2> mask{false, true};
326 
327     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
328     Value cast = rewriter.create<vector::ShapeCastOp>(
329         loc, castedType, multiReductionOp.getSource());
330     Value castAcc = rewriter.create<vector::BroadcastOp>(
331         loc, accType, multiReductionOp.getAcc());
332     Value reduced = rewriter.create<vector::MultiDimReductionOp>(
333         loc, cast, castAcc, mask, multiReductionOp.getKind());
334     rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
335                                                    ArrayRef<int64_t>{0});
336     return success();
337   }
338 };
339 
populateVectorMultiReductionLoweringPatterns(RewritePatternSet & patterns,VectorMultiReductionLowering options)340 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
341     RewritePatternSet &patterns, VectorMultiReductionLowering options) {
342   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
343       patterns.getContext(), options);
344   patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
345   if (options == VectorMultiReductionLowering ::InnerReduction)
346     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
347   else
348     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
349 }
350