//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
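///
/// For example, with the InnerReduction lowering (an illustrative sketch:
/// %src and the shapes are arbitrary, and the assembly syntax is only
/// indicative):
///
///   %0 = vector.multi_reduction <add>, %src [0, 2]
///     : vector<2x3x4xf32> to vector<3xf32>
///
/// is rewritten into
///
///   %t = vector.transpose %src, [1, 0, 2]
///     : vector<2x3x4xf32> to vector<3x2x4xf32>
///   %0 = vector.multi_reduction <add>, %t [1, 2]
///     : vector<3x2x4xf32> to vector<3xf32>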
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    auto reductionDimsRange =
        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add a transpose only if there are parallel dims and the reduction dims
    // are not already in the desired inner-most/outer-most position.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims !=
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
        multiReductionOp, transposeOp.getResult(), reductionMask,
        multiReductionOp.getKind());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, given that all
/// reduction dimensions are either inner-most or outer-most.
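///
/// For example, with inner-most reduction dims (an illustrative sketch:
/// %src and the shapes are arbitrary, and the assembly syntax is only
/// indicative):
///
///   %0 = vector.multi_reduction <add>, %src [2, 3]
///     : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten into
///
///   %c = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %r = vector.multi_reduction <add>, %c [1]
///     : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>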
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If the rank is less than 2, there is nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must have at least some parallel or some reduction dims.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if the reduction / parallel dims are not contiguous.
    // Check that parallelDims is exactly [0 .. parallelDims.size()).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check that parallelDims is exactly
    // [reductionDims.size() .. reductionDims.size() + parallelDims.size()).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse the consecutive parallel (resp. reduction)
    // dims into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    // 5. Creates the flattened form of vector.multi_reduction, with the
    // inner-most/outer-most dim as the reduction dim.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.getKind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.getDest());
      return success();
    }

    // 7. Creates a shape cast to expand the flattened 1-D result back to the
    // original n-D parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.getDest());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls a 2-D vector.multi_reduction with an outer-most reduction
/// dimension into a sequence of elementwise operations that combine the
/// extracted rows.
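///
/// For example (an illustrative sketch: %src and the shapes are arbitrary,
/// the assembly syntax is only indicative, and the op ordering is not shown
/// exactly as produced):
///
///   %0 = vector.multi_reduction <add>, %src [0]
///     : vector<3x8xf32> to vector<8xf32>
///
/// is rewritten into
///
///   %r0 = vector.extract %src[0] : vector<3x8xf32>
///   %r1 = vector.extract %src[1] : vector<3x8xf32>
///   %r2 = vector.extract %src[2] : vector<3x8xf32>
///   %a0 = arith.addf %r1, %r0 : vector<8xf32>
///   %0  = arith.addf %r2, %a0 : vector<8xf32>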
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result =
        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0)
            .getResult();
    for (int64_t i = 1; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with the inner-most dimension as the
/// reduction dim into a sequence of vector.reduction ops.
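///
/// For example (an illustrative sketch: %src and the shapes are arbitrary,
/// and the assembly syntax is only indicative):
///
///   %0 = vector.multi_reduction <add>, %src [1]
///     : vector<2x8xf32> to vector<2xf32>
///
/// is rewritten into
///
///   %acc = arith.constant dense<0.0> : vector<2xf32>
///   %c0 = arith.constant 0 : index
///   %c1 = arith.constant 1 : index
///   %e0 = vector.extract %src[0] : vector<2x8xf32>
///   %r0 = vector.reduction <add>, %e0 : vector<8xf32> into f32
///   %v0 = vector.insertelement %r0, %acc[%c0 : index] : vector<2xf32>
///   %e1 = vector.extract %src[1] : vector<2x8xf32>
///   %r1 = vector.reduction <add>, %e1 : vector<8xf32> into f32
///   %0  = vector.insertelement %r1, %v0[%c1 : index] : vector<2xf32>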
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto reducedValue = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v);
      result = rewriter.create<vector::InsertElementOp>(
          loc, reducedValue, result,
          rewriter.create<arith::ConstantIndexOp>(loc, i));
    }
    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension to
/// a 2-D form with both a single parallel and a single reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
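///
/// For example (an illustrative sketch: %src and the shape are arbitrary,
/// and the assembly syntax is only indicative):
///
///   %0 = vector.multi_reduction <add>, %src [0] : vector<8xf32> to f32
///
/// is rewritten into
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %r = vector.multi_reduction <add>, %c [1]
///     : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : vector<1xf32>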
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                      srcVectorType.getElementType());
    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel dim in front, we
    // need a {false, true} mask.
    SmallVector<bool, 2> mask{false, true};

    // vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.getKind());
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

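/// Typical usage of the entry point below, e.g. from within a lowering pass
/// (an illustrative sketch only: the enclosing pass, the choice of `op`, and
/// the greedy driver invocation are assumptions, not part of this file):
///
///   RewritePatternSet patterns(op->getContext());
///   vector::populateVectorMultiReductionLoweringPatterns(
///       patterns, vector::VectorMultiReductionLowering::InnerReduction);
///   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));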
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
}