//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/ImplicitLocOpBuilder.h"
17 #include "mlir/IR/TypeUtilities.h"
18 
19 #define DEBUG_TYPE "vector-multi-reduction"
20 
21 using namespace mlir;
22 
23 /// This file implements the following transformations as composable atomic
24 /// patterns.
25 
26 /// Converts vector.multi_reduction into inner-most/outer-most reduction form
27 /// by using vector.transpose
28 class InnerOuterDimReductionConversion
29     : public OpRewritePattern<vector::MultiDimReductionOp> {
30 public:
31   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
32 
33   explicit InnerOuterDimReductionConversion(
34       MLIRContext *context, vector::VectorMultiReductionLowering options)
35       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
36         useInnerDimsForReduction(
37             options == vector::VectorMultiReductionLowering::InnerReduction) {}
38 
39   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
40                                 PatternRewriter &rewriter) const override {
41     auto src = multiReductionOp.source();
42     auto loc = multiReductionOp.getLoc();
43     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
44 
45     // Separate reduction and parallel dims
46     auto reductionDimsRange =
47         multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
48     auto reductionDims = llvm::to_vector<4>(llvm::map_range(
49         reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
50     llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
51                                                   reductionDims.end());
52     int64_t reductionSize = reductionDims.size();
53     SmallVector<int64_t, 4> parallelDims;
54     for (int64_t i = 0; i < srcRank; ++i)
55       if (!reductionDimsSet.contains(i))
56         parallelDims.push_back(i);
57 
58     // Add transpose only if inner-most/outer-most dimensions are not parallel
59     // and there are parallel dims.
60     if (parallelDims.empty())
61       return failure();
62     if (useInnerDimsForReduction &&
63         (parallelDims ==
64          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
65       return failure();
66 
67     if (!useInnerDimsForReduction &&
68         (parallelDims !=
69          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
70       return failure();
71 
72     SmallVector<int64_t, 4> indices;
73     if (useInnerDimsForReduction) {
74       indices.append(parallelDims.begin(), parallelDims.end());
75       indices.append(reductionDims.begin(), reductionDims.end());
76     } else {
77       indices.append(reductionDims.begin(), reductionDims.end());
78       indices.append(parallelDims.begin(), parallelDims.end());
79     }
80     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
81     SmallVector<bool> reductionMask(srcRank, false);
82     for (int i = 0; i < reductionSize; ++i) {
83       if (useInnerDimsForReduction)
84         reductionMask[srcRank - i - 1] = true;
85       else
86         reductionMask[i] = true;
87     }
88     rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
89         multiReductionOp, transposeOp.result(), reductionMask,
90         multiReductionOp.kind());
91     return success();
92   }
93 
94 private:
95   const bool useInnerDimsForReduction;
96 };
97 
/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either inner most or outer most: the source is shape_cast
/// into a rank-2 (or rank-1 when no parallel dims remain) vector, reduced,
/// and the result is shape_cast back to the original parallel shape.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  /// Records whether the collapsed reduction dim should end up inner-most
  /// (InnerReduction) or outer-most (any other lowering option).
  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes as the product of the
    // corresponding dim sizes; 0 is the sentinel for "no such dims".
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // For inner reduction, parallelDims must be exactly [0 .. size).
    // Note: the predicate increments `counter` as a side effect on each
    // element it inspects.
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // For outer reduction, parallelDims must be exactly
    // [reductionDims.size() .. reductionDims.size() + size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
    // a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    // The code above always emits the parallel dim first; swap when the
    // reduction dim must end up outer-most.
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());

    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.dest());
      return success();
    }

    // 7. Shape cast the collapsed 1-D result back to the original n-D
    // parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.dest());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};
215 
216 /// Unrolls vector.multi_reduction with outermost reductions
217 /// and combines results
218 struct TwoDimMultiReductionToElementWise
219     : public OpRewritePattern<vector::MultiDimReductionOp> {
220   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
221 
222   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
223                                 PatternRewriter &rewriter) const override {
224     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
225     // Rank-2 ["parallel", "reduce"] or bail.
226     if (srcRank != 2)
227       return failure();
228 
229     if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
230       return failure();
231 
232     auto loc = multiReductionOp.getLoc();
233     ArrayRef<int64_t> srcShape =
234         multiReductionOp.getSourceVectorType().getShape();
235 
236     Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
237     if (!elementType.isIntOrIndexOrFloat())
238       return failure();
239 
240     Value result =
241         rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
242             .getResult();
243     for (int64_t i = 1; i < srcShape[0]; i++) {
244       auto operand =
245           rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
246       switch (multiReductionOp.kind()) {
247       case vector::CombiningKind::ADD:
248         if (elementType.isIntOrIndex())
249           result = rewriter.create<arith::AddIOp>(loc, operand, result);
250         else
251           result = rewriter.create<arith::AddFOp>(loc, operand, result);
252         break;
253       case vector::CombiningKind::MUL:
254         if (elementType.isIntOrIndex())
255           result = rewriter.create<arith::MulIOp>(loc, operand, result);
256         else
257           result = rewriter.create<arith::MulFOp>(loc, operand, result);
258         break;
259       case vector::CombiningKind::MINUI:
260         result = rewriter.create<arith::MinUIOp>(loc, operand, result);
261         break;
262       case vector::CombiningKind::MINSI:
263         result = rewriter.create<arith::MinSIOp>(loc, operand, result);
264         break;
265       case vector::CombiningKind::MINF:
266         result = rewriter.create<arith::MinFOp>(loc, operand, result);
267         break;
268       case vector::CombiningKind::MAXUI:
269         result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
270         break;
271       case vector::CombiningKind::MAXSI:
272         result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
273         break;
274       case vector::CombiningKind::MAXF:
275         result = rewriter.create<arith::MaxFOp>(loc, operand, result);
276         break;
277       case vector::CombiningKind::AND:
278         result = rewriter.create<arith::AndIOp>(loc, operand, result);
279         break;
280       case vector::CombiningKind::OR:
281         result = rewriter.create<arith::OrIOp>(loc, operand, result);
282         break;
283       case vector::CombiningKind::XOR:
284         result = rewriter.create<arith::XOrIOp>(loc, operand, result);
285         break;
286       }
287     }
288 
289     rewriter.replaceOp(multiReductionOp, result);
290     return success();
291   }
292 };
293 
294 /// Converts 2d vector.multi_reduction with inner most reduction dimension into
295 /// a sequence of vector.reduction ops.
296 struct TwoDimMultiReductionToReduction
297     : public OpRewritePattern<vector::MultiDimReductionOp> {
298   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
299 
300   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
301                                 PatternRewriter &rewriter) const override {
302     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
303     if (srcRank != 2)
304       return failure();
305 
306     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
307       return failure();
308 
309     auto loc = multiReductionOp.getLoc();
310     Value result = rewriter.create<ConstantOp>(
311         loc, multiReductionOp.getDestType(),
312         rewriter.getZeroAttr(multiReductionOp.getDestType()));
313     int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
314 
315     // TODO: Add vector::CombiningKind attribute instead of string to
316     // vector.reduction.
317     auto getKindStr = [](vector::CombiningKind kind) {
318       switch (kind) {
319       case vector::CombiningKind::ADD:
320         return "add";
321       case vector::CombiningKind::MUL:
322         return "mul";
323       case vector::CombiningKind::MINUI:
324         return "minui";
325       case vector::CombiningKind::MINSI:
326         return "minsi";
327       case vector::CombiningKind::MINF:
328         return "minf";
329       case vector::CombiningKind::MAXUI:
330         return "maxui";
331       case vector::CombiningKind::MAXSI:
332         return "maxsi";
333       case vector::CombiningKind::MAXF:
334         return "maxf";
335       case vector::CombiningKind::AND:
336         return "and";
337       case vector::CombiningKind::OR:
338         return "or";
339       case vector::CombiningKind::XOR:
340         return "xor";
341       }
342       llvm_unreachable("unknown combining kind");
343     };
344 
345     for (int i = 0; i < outerDim; ++i) {
346       auto v = rewriter.create<vector::ExtractOp>(
347           loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
348       auto reducedValue = rewriter.create<vector::ReductionOp>(
349           loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
350           rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
351           ValueRange{});
352       result = rewriter.create<vector::InsertElementOp>(
353           loc, reducedValue, result,
354           rewriter.create<arith::ConstantIndexOp>(loc, i));
355     }
356     rewriter.replaceOp(multiReductionOp, result);
357     return success();
358   }
359 };
360 
361 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
362 /// form with both a single parallel and reduction dimension.
363 /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
364 /// The case with a single parallel dimension is a noop and folds away
365 /// separately.
366 struct OneDimMultiReductionToTwoDim
367     : public OpRewritePattern<vector::MultiDimReductionOp> {
368   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
369 
370   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
371                                 PatternRewriter &rewriter) const override {
372     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
373     // Rank-1 or bail.
374     if (srcRank != 1)
375       return failure();
376 
377     auto loc = multiReductionOp.getLoc();
378     auto srcVectorType = multiReductionOp.getSourceVectorType();
379     auto srcShape = srcVectorType.getShape();
380     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
381                                       srcVectorType.getElementType());
382     assert(!multiReductionOp.getDestType().isa<VectorType>() &&
383            "multi_reduction with a single dimension expects a scalar result");
384 
385     // If the unique dim is reduced and we insert a parallel in front, we need a
386     // {false, true} mask.
387     SmallVector<bool, 2> mask{false, true};
388 
389     /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
390     Value cast = rewriter.create<vector::ShapeCastOp>(
391         loc, castedType, multiReductionOp.source());
392     Value reduced = rewriter.create<vector::MultiDimReductionOp>(
393         loc, cast, mask, multiReductionOp.kind());
394     rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
395                                                    ArrayRef<int64_t>{0});
396     return success();
397   }
398 };
399 
400 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
401     RewritePatternSet &patterns, VectorMultiReductionLowering options) {
402   patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
403       patterns.getContext(), options);
404   patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
405   if (options == VectorMultiReductionLowering ::InnerReduction)
406     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
407   else
408     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
409 }
410