//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
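///
/// For instance (illustrative only; exact IR syntax may differ across MLIR
/// versions), with the inner-reduction lowering an outer reduction such as
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<2x4xf32> to vector<4xf32>
///
/// is rewritten so that the reduced dimension becomes innermost:
///
///   %t = vector.transpose %src, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [1]
///          : vector<4x2xf32> to vector<4xf32>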
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    auto reductionDimsRange =
        multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims !=
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
        reductionMask, multiReductionOp.getKind());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either inner most or outer most.
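///
/// For instance (illustrative only; exact IR syntax may differ), with the
/// inner-reduction lowering
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///          : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten as a rank-2 reduction surrounded by shape casts:
///
///   %in  = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a   = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %red = vector.multi_reduction <add>, %in, %a [1]
///            : vector<6x20xf32> to vector<6xf32>
///   %0   = vector.shape_cast %red : vector<6xf32> to vector<2x3xf32>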
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType());
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 5. Create the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.getDest());
      return success();
    }
    // 7. Create a shape cast to expand the flattened result back to the
    // original n-D parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.getDest());
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results.
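///
/// For instance (illustrative only; exact IR syntax may differ), the outer
/// reduction
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<2x4xf32> to vector<4xf32>
///
/// is unrolled along the outermost dimension into elementwise arith ops:
///
///   %r0 = vector.extract %src[0] : vector<2x4xf32>
///   %r1 = vector.extract %src[1] : vector<2x4xf32>
///   %s0 = arith.addf %r0, %acc : vector<4xf32>
///   %0  = arith.addf %r1, %s0 : vector<4xf32>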
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] form or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with inner most reduction dimension into
/// a sequence of vector.reduction ops.
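///
/// For instance (illustrative only; exact IR syntax may differ), the inner
/// reduction
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x4xf32> to vector<2xf32>
///
/// becomes one vector.reduction per outer row, with the scalar results
/// inserted back into a vector<2xf32>:
///
///   %v0 = vector.extract %src[0] : vector<2x4xf32>
///   %a0 = vector.extract %acc[0] : vector<2xf32>
///   %r0 = vector.reduction <add>, %v0, %a0 : vector<4xf32> into f32
///   ...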
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      auto reducedValue = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);
      result = rewriter.create<vector::InsertElementOp>(
          loc, reducedValue, result,
          rewriter.create<arith::ConstantIndexOp>(loc, i));
    }
    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 1d vector.multi_reduction with a single reduction dimension to a
/// 2d form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
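///
/// For instance (illustrative only; exact IR syntax may differ),
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
///
/// becomes roughly:
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %b = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %b [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : vector<1xf32>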
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                      srcVectorType.getElementType());
    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need
    // a {false, true} mask.
    SmallVector<bool, 2> mask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, mask, multiReductionOp.getKind());
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
}