199ef9eebSMatthias Springer //===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
299ef9eebSMatthias Springer //
399ef9eebSMatthias Springer /// Part of the LLVM Project, under the Apache License v2.0 with LLVM
499ef9eebSMatthias Springer /// Exceptions. See https://llvm.org/LICENSE.txt for license information.
599ef9eebSMatthias Springer /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699ef9eebSMatthias Springer //
799ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
899ef9eebSMatthias Springer //
999ef9eebSMatthias Springer /// This file implements target-independent rewrites of MultiDimReductionOp.
1099ef9eebSMatthias Springer //
1199ef9eebSMatthias Springer //===----------------------------------------------------------------------===//
1299ef9eebSMatthias Springer
13eda6f907SRiver Riddle #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1499ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1599ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1699ef9eebSMatthias Springer #include "mlir/IR/Builders.h"
1799ef9eebSMatthias Springer #include "mlir/IR/ImplicitLocOpBuilder.h"
1899ef9eebSMatthias Springer #include "mlir/IR/TypeUtilities.h"
1999ef9eebSMatthias Springer
2099ef9eebSMatthias Springer #define DEBUG_TYPE "vector-multi-reduction"
2199ef9eebSMatthias Springer
2299ef9eebSMatthias Springer using namespace mlir;
2399ef9eebSMatthias Springer
2499ef9eebSMatthias Springer /// This file implements the following transformations as composable atomic
2599ef9eebSMatthias Springer /// patterns.
2699ef9eebSMatthias Springer
2799ef9eebSMatthias Springer /// Converts vector.multi_reduction into inner-most/outer-most reduction form
2899ef9eebSMatthias Springer /// by using vector.transpose
2999ef9eebSMatthias Springer class InnerOuterDimReductionConversion
3099ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
3199ef9eebSMatthias Springer public:
3299ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
3399ef9eebSMatthias Springer
InnerOuterDimReductionConversion(MLIRContext * context,vector::VectorMultiReductionLowering options)3499ef9eebSMatthias Springer explicit InnerOuterDimReductionConversion(
3599ef9eebSMatthias Springer MLIRContext *context, vector::VectorMultiReductionLowering options)
3699ef9eebSMatthias Springer : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
3799ef9eebSMatthias Springer useInnerDimsForReduction(
3899ef9eebSMatthias Springer options == vector::VectorMultiReductionLowering::InnerReduction) {}
3999ef9eebSMatthias Springer
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const4099ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
4199ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
427c38fd60SJacques Pienaar auto src = multiReductionOp.getSource();
4399ef9eebSMatthias Springer auto loc = multiReductionOp.getLoc();
4499ef9eebSMatthias Springer auto srcRank = multiReductionOp.getSourceVectorType().getRank();
4599ef9eebSMatthias Springer
4699ef9eebSMatthias Springer // Separate reduction and parallel dims
4799ef9eebSMatthias Springer auto reductionDimsRange =
487c38fd60SJacques Pienaar multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
4999ef9eebSMatthias Springer auto reductionDims = llvm::to_vector<4>(llvm::map_range(
5099ef9eebSMatthias Springer reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
5199ef9eebSMatthias Springer llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
5299ef9eebSMatthias Springer reductionDims.end());
5399ef9eebSMatthias Springer int64_t reductionSize = reductionDims.size();
5499ef9eebSMatthias Springer SmallVector<int64_t, 4> parallelDims;
5599ef9eebSMatthias Springer for (int64_t i = 0; i < srcRank; ++i)
5699ef9eebSMatthias Springer if (!reductionDimsSet.contains(i))
5799ef9eebSMatthias Springer parallelDims.push_back(i);
5899ef9eebSMatthias Springer
5999ef9eebSMatthias Springer // Add transpose only if inner-most/outer-most dimensions are not parallel
6099ef9eebSMatthias Springer // and there are parallel dims.
6199ef9eebSMatthias Springer if (parallelDims.empty())
6299ef9eebSMatthias Springer return failure();
6399ef9eebSMatthias Springer if (useInnerDimsForReduction &&
6499ef9eebSMatthias Springer (parallelDims ==
6599ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
6699ef9eebSMatthias Springer return failure();
6799ef9eebSMatthias Springer
6899ef9eebSMatthias Springer if (!useInnerDimsForReduction &&
6999ef9eebSMatthias Springer (parallelDims !=
7099ef9eebSMatthias Springer llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
7199ef9eebSMatthias Springer return failure();
7299ef9eebSMatthias Springer
7399ef9eebSMatthias Springer SmallVector<int64_t, 4> indices;
7499ef9eebSMatthias Springer if (useInnerDimsForReduction) {
7599ef9eebSMatthias Springer indices.append(parallelDims.begin(), parallelDims.end());
7699ef9eebSMatthias Springer indices.append(reductionDims.begin(), reductionDims.end());
7799ef9eebSMatthias Springer } else {
7899ef9eebSMatthias Springer indices.append(reductionDims.begin(), reductionDims.end());
7999ef9eebSMatthias Springer indices.append(parallelDims.begin(), parallelDims.end());
8099ef9eebSMatthias Springer }
8199ef9eebSMatthias Springer auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
8299ef9eebSMatthias Springer SmallVector<bool> reductionMask(srcRank, false);
8399ef9eebSMatthias Springer for (int i = 0; i < reductionSize; ++i) {
8499ef9eebSMatthias Springer if (useInnerDimsForReduction)
8599ef9eebSMatthias Springer reductionMask[srcRank - i - 1] = true;
8699ef9eebSMatthias Springer else
8799ef9eebSMatthias Springer reductionMask[i] = true;
8899ef9eebSMatthias Springer }
8999ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
90*051b36baSThomas Raoux multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
91*051b36baSThomas Raoux reductionMask, multiReductionOp.getKind());
9299ef9eebSMatthias Springer return success();
9399ef9eebSMatthias Springer }
9499ef9eebSMatthias Springer
9599ef9eebSMatthias Springer private:
9699ef9eebSMatthias Springer const bool useInnerDimsForReduction;
9799ef9eebSMatthias Springer };
9899ef9eebSMatthias Springer
9999ef9eebSMatthias Springer /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
10099ef9eebSMatthias Springer /// dimensions are either inner most or outer most.
10199ef9eebSMatthias Springer class ReduceMultiDimReductionRank
10299ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
10399ef9eebSMatthias Springer public:
10499ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
10599ef9eebSMatthias Springer
ReduceMultiDimReductionRank(MLIRContext * context,vector::VectorMultiReductionLowering options)10699ef9eebSMatthias Springer explicit ReduceMultiDimReductionRank(
10799ef9eebSMatthias Springer MLIRContext *context, vector::VectorMultiReductionLowering options)
10899ef9eebSMatthias Springer : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
10999ef9eebSMatthias Springer useInnerDimsForReduction(
11099ef9eebSMatthias Springer options == vector::VectorMultiReductionLowering::InnerReduction) {}
11199ef9eebSMatthias Springer
matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,PatternRewriter & rewriter) const11299ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
11399ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
11499ef9eebSMatthias Springer auto srcRank = multiReductionOp.getSourceVectorType().getRank();
11599ef9eebSMatthias Springer auto srcShape = multiReductionOp.getSourceVectorType().getShape();
11699ef9eebSMatthias Springer auto loc = multiReductionOp.getLoc();
11799ef9eebSMatthias Springer
11899ef9eebSMatthias Springer // If rank less than 2, nothing to do.
11999ef9eebSMatthias Springer if (srcRank < 2)
12099ef9eebSMatthias Springer return failure();
12199ef9eebSMatthias Springer
12299ef9eebSMatthias Springer // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
12399ef9eebSMatthias Springer SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
12499ef9eebSMatthias Springer if (srcRank == 2 && reductionMask.front() != reductionMask.back())
12599ef9eebSMatthias Springer return failure();
12699ef9eebSMatthias Springer
12799ef9eebSMatthias Springer // 1. Separate reduction and parallel dims.
12899ef9eebSMatthias Springer SmallVector<int64_t, 4> parallelDims, parallelShapes;
12999ef9eebSMatthias Springer SmallVector<int64_t, 4> reductionDims, reductionShapes;
13099ef9eebSMatthias Springer for (const auto &it : llvm::enumerate(reductionMask)) {
13199ef9eebSMatthias Springer int64_t i = it.index();
13299ef9eebSMatthias Springer bool isReduction = it.value();
13399ef9eebSMatthias Springer if (isReduction) {
13499ef9eebSMatthias Springer reductionDims.push_back(i);
13599ef9eebSMatthias Springer reductionShapes.push_back(srcShape[i]);
13699ef9eebSMatthias Springer } else {
13799ef9eebSMatthias Springer parallelDims.push_back(i);
13899ef9eebSMatthias Springer parallelShapes.push_back(srcShape[i]);
13999ef9eebSMatthias Springer }
14099ef9eebSMatthias Springer }
14199ef9eebSMatthias Springer
14299ef9eebSMatthias Springer // 2. Compute flattened parallel and reduction sizes.
14399ef9eebSMatthias Springer int flattenedParallelDim = 0;
14499ef9eebSMatthias Springer int flattenedReductionDim = 0;
14599ef9eebSMatthias Springer if (!parallelShapes.empty()) {
14699ef9eebSMatthias Springer flattenedParallelDim = 1;
14799ef9eebSMatthias Springer for (auto d : parallelShapes)
14899ef9eebSMatthias Springer flattenedParallelDim *= d;
14999ef9eebSMatthias Springer }
15099ef9eebSMatthias Springer if (!reductionShapes.empty()) {
15199ef9eebSMatthias Springer flattenedReductionDim = 1;
15299ef9eebSMatthias Springer for (auto d : reductionShapes)
15399ef9eebSMatthias Springer flattenedReductionDim *= d;
15499ef9eebSMatthias Springer }
15599ef9eebSMatthias Springer // We must at least have some parallel or some reduction.
15699ef9eebSMatthias Springer assert((flattenedParallelDim || flattenedReductionDim) &&
15799ef9eebSMatthias Springer "expected at least one parallel or reduction dim");
15899ef9eebSMatthias Springer
15999ef9eebSMatthias Springer // 3. Fail if reduction/parallel dims are not contiguous.
16099ef9eebSMatthias Springer // Check parallelDims are exactly [0 .. size).
16199ef9eebSMatthias Springer int64_t counter = 0;
16299ef9eebSMatthias Springer if (useInnerDimsForReduction &&
16399ef9eebSMatthias Springer llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
16499ef9eebSMatthias Springer return failure();
16599ef9eebSMatthias Springer // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
16699ef9eebSMatthias Springer counter = reductionDims.size();
16799ef9eebSMatthias Springer if (!useInnerDimsForReduction &&
16899ef9eebSMatthias Springer llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
16999ef9eebSMatthias Springer return failure();
17099ef9eebSMatthias Springer
17199ef9eebSMatthias Springer // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
17299ef9eebSMatthias Springer // a single parallel (resp. reduction) dim.
17399ef9eebSMatthias Springer SmallVector<bool, 2> mask;
17499ef9eebSMatthias Springer SmallVector<int64_t, 2> vectorShape;
17599ef9eebSMatthias Springer if (flattenedParallelDim) {
17699ef9eebSMatthias Springer mask.push_back(false);
17799ef9eebSMatthias Springer vectorShape.push_back(flattenedParallelDim);
17899ef9eebSMatthias Springer }
17999ef9eebSMatthias Springer if (flattenedReductionDim) {
18099ef9eebSMatthias Springer mask.push_back(true);
18199ef9eebSMatthias Springer vectorShape.push_back(flattenedReductionDim);
18299ef9eebSMatthias Springer }
18399ef9eebSMatthias Springer if (!useInnerDimsForReduction && vectorShape.size() == 2) {
18499ef9eebSMatthias Springer std::swap(mask.front(), mask.back());
18599ef9eebSMatthias Springer std::swap(vectorShape.front(), vectorShape.back());
18699ef9eebSMatthias Springer }
18799ef9eebSMatthias Springer auto castedType = VectorType::get(
18899ef9eebSMatthias Springer vectorShape, multiReductionOp.getSourceVectorType().getElementType());
18999ef9eebSMatthias Springer Value cast = rewriter.create<vector::ShapeCastOp>(
1907c38fd60SJacques Pienaar loc, castedType, multiReductionOp.getSource());
191*051b36baSThomas Raoux Value acc = multiReductionOp.getAcc();
192*051b36baSThomas Raoux if (flattenedParallelDim) {
193*051b36baSThomas Raoux auto accType = VectorType::get(
194*051b36baSThomas Raoux {flattenedParallelDim},
195*051b36baSThomas Raoux multiReductionOp.getSourceVectorType().getElementType());
196*051b36baSThomas Raoux acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
197*051b36baSThomas Raoux }
19899ef9eebSMatthias Springer // 5. Creates the flattened form of vector.multi_reduction with inner/outer
19999ef9eebSMatthias Springer // most dim as reduction.
20099ef9eebSMatthias Springer auto newOp = rewriter.create<vector::MultiDimReductionOp>(
201*051b36baSThomas Raoux loc, cast, acc, mask, multiReductionOp.getKind());
20299ef9eebSMatthias Springer
20399ef9eebSMatthias Springer // 6. If there are no parallel shapes, the result is a scalar.
20499ef9eebSMatthias Springer // TODO: support 0-d vectors when available.
20599ef9eebSMatthias Springer if (parallelShapes.empty()) {
2067c38fd60SJacques Pienaar rewriter.replaceOp(multiReductionOp, newOp.getDest());
20799ef9eebSMatthias Springer return success();
20899ef9eebSMatthias Springer }
20999ef9eebSMatthias Springer
21099ef9eebSMatthias Springer // 7. Creates shape cast for the output n-D -> 2-D
21199ef9eebSMatthias Springer VectorType outputCastedType = VectorType::get(
21299ef9eebSMatthias Springer parallelShapes,
21399ef9eebSMatthias Springer multiReductionOp.getSourceVectorType().getElementType());
21499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2157c38fd60SJacques Pienaar multiReductionOp, outputCastedType, newOp.getDest());
21699ef9eebSMatthias Springer return success();
21799ef9eebSMatthias Springer }
21899ef9eebSMatthias Springer
21999ef9eebSMatthias Springer private:
22099ef9eebSMatthias Springer const bool useInnerDimsForReduction;
22199ef9eebSMatthias Springer };
22299ef9eebSMatthias Springer
22399ef9eebSMatthias Springer /// Unrolls vector.multi_reduction with outermost reductions
22499ef9eebSMatthias Springer /// and combines results
22599ef9eebSMatthias Springer struct TwoDimMultiReductionToElementWise
22699ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
22799ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
22899ef9eebSMatthias Springer
matchAndRewriteTwoDimMultiReductionToElementWise22999ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
23099ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
23199ef9eebSMatthias Springer auto srcRank = multiReductionOp.getSourceVectorType().getRank();
23299ef9eebSMatthias Springer // Rank-2 ["parallel", "reduce"] or bail.
23399ef9eebSMatthias Springer if (srcRank != 2)
23499ef9eebSMatthias Springer return failure();
23599ef9eebSMatthias Springer
23699ef9eebSMatthias Springer if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
23799ef9eebSMatthias Springer return failure();
23899ef9eebSMatthias Springer
23999ef9eebSMatthias Springer auto loc = multiReductionOp.getLoc();
24099ef9eebSMatthias Springer ArrayRef<int64_t> srcShape =
24199ef9eebSMatthias Springer multiReductionOp.getSourceVectorType().getShape();
24299ef9eebSMatthias Springer
24399ef9eebSMatthias Springer Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
24499ef9eebSMatthias Springer if (!elementType.isIntOrIndexOrFloat())
24599ef9eebSMatthias Springer return failure();
24699ef9eebSMatthias Springer
247*051b36baSThomas Raoux Value result = multiReductionOp.getAcc();
248*051b36baSThomas Raoux for (int64_t i = 0; i < srcShape[0]; i++) {
2497c38fd60SJacques Pienaar auto operand = rewriter.create<vector::ExtractOp>(
2507c38fd60SJacques Pienaar loc, multiReductionOp.getSource(), i);
2517c38fd60SJacques Pienaar result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
2529b5a3d14SMatthias Springer operand, result);
25399ef9eebSMatthias Springer }
25499ef9eebSMatthias Springer
25599ef9eebSMatthias Springer rewriter.replaceOp(multiReductionOp, result);
25699ef9eebSMatthias Springer return success();
25799ef9eebSMatthias Springer }
25899ef9eebSMatthias Springer };
25999ef9eebSMatthias Springer
26099ef9eebSMatthias Springer /// Converts 2d vector.multi_reduction with inner most reduction dimension into
26199ef9eebSMatthias Springer /// a sequence of vector.reduction ops.
26299ef9eebSMatthias Springer struct TwoDimMultiReductionToReduction
26399ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
26499ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
26599ef9eebSMatthias Springer
matchAndRewriteTwoDimMultiReductionToReduction26699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
26799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
26899ef9eebSMatthias Springer auto srcRank = multiReductionOp.getSourceVectorType().getRank();
26999ef9eebSMatthias Springer if (srcRank != 2)
27099ef9eebSMatthias Springer return failure();
27199ef9eebSMatthias Springer
27299ef9eebSMatthias Springer if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
27399ef9eebSMatthias Springer return failure();
27499ef9eebSMatthias Springer
27599ef9eebSMatthias Springer auto loc = multiReductionOp.getLoc();
2768e123ca6SRiver Riddle Value result = rewriter.create<arith::ConstantOp>(
27799ef9eebSMatthias Springer loc, multiReductionOp.getDestType(),
27899ef9eebSMatthias Springer rewriter.getZeroAttr(multiReductionOp.getDestType()));
27999ef9eebSMatthias Springer int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
28099ef9eebSMatthias Springer
28199ef9eebSMatthias Springer for (int i = 0; i < outerDim; ++i) {
28299ef9eebSMatthias Springer auto v = rewriter.create<vector::ExtractOp>(
2837c38fd60SJacques Pienaar loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
284*051b36baSThomas Raoux auto acc = rewriter.create<vector::ExtractOp>(
285*051b36baSThomas Raoux loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
2867c38fd60SJacques Pienaar auto reducedValue = rewriter.create<vector::ReductionOp>(
287*051b36baSThomas Raoux loc, multiReductionOp.getKind(), v, acc);
28899ef9eebSMatthias Springer result = rewriter.create<vector::InsertElementOp>(
28999ef9eebSMatthias Springer loc, reducedValue, result,
29099ef9eebSMatthias Springer rewriter.create<arith::ConstantIndexOp>(loc, i));
29199ef9eebSMatthias Springer }
29299ef9eebSMatthias Springer rewriter.replaceOp(multiReductionOp, result);
29399ef9eebSMatthias Springer return success();
29499ef9eebSMatthias Springer }
29599ef9eebSMatthias Springer };
29699ef9eebSMatthias Springer
29799ef9eebSMatthias Springer /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
29899ef9eebSMatthias Springer /// form with both a single parallel and reduction dimension.
29999ef9eebSMatthias Springer /// This is achieved with a simple vector.shape_cast that inserts a leading 1.
30099ef9eebSMatthias Springer /// The case with a single parallel dimension is a noop and folds away
30199ef9eebSMatthias Springer /// separately.
30299ef9eebSMatthias Springer struct OneDimMultiReductionToTwoDim
30399ef9eebSMatthias Springer : public OpRewritePattern<vector::MultiDimReductionOp> {
30499ef9eebSMatthias Springer using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
30599ef9eebSMatthias Springer
matchAndRewriteOneDimMultiReductionToTwoDim30699ef9eebSMatthias Springer LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
30799ef9eebSMatthias Springer PatternRewriter &rewriter) const override {
30899ef9eebSMatthias Springer auto srcRank = multiReductionOp.getSourceVectorType().getRank();
30999ef9eebSMatthias Springer // Rank-1 or bail.
31099ef9eebSMatthias Springer if (srcRank != 1)
31199ef9eebSMatthias Springer return failure();
31299ef9eebSMatthias Springer
31399ef9eebSMatthias Springer auto loc = multiReductionOp.getLoc();
31499ef9eebSMatthias Springer auto srcVectorType = multiReductionOp.getSourceVectorType();
31599ef9eebSMatthias Springer auto srcShape = srcVectorType.getShape();
31699ef9eebSMatthias Springer auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
31799ef9eebSMatthias Springer srcVectorType.getElementType());
318*051b36baSThomas Raoux auto accType =
319*051b36baSThomas Raoux VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
32099ef9eebSMatthias Springer assert(!multiReductionOp.getDestType().isa<VectorType>() &&
32199ef9eebSMatthias Springer "multi_reduction with a single dimension expects a scalar result");
32299ef9eebSMatthias Springer
32399ef9eebSMatthias Springer // If the unique dim is reduced and we insert a parallel in front, we need a
32499ef9eebSMatthias Springer // {false, true} mask.
32599ef9eebSMatthias Springer SmallVector<bool, 2> mask{false, true};
32699ef9eebSMatthias Springer
32799ef9eebSMatthias Springer /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
32899ef9eebSMatthias Springer Value cast = rewriter.create<vector::ShapeCastOp>(
3297c38fd60SJacques Pienaar loc, castedType, multiReductionOp.getSource());
330*051b36baSThomas Raoux Value castAcc = rewriter.create<vector::BroadcastOp>(
331*051b36baSThomas Raoux loc, accType, multiReductionOp.getAcc());
33299ef9eebSMatthias Springer Value reduced = rewriter.create<vector::MultiDimReductionOp>(
333*051b36baSThomas Raoux loc, cast, castAcc, mask, multiReductionOp.getKind());
33499ef9eebSMatthias Springer rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
33599ef9eebSMatthias Springer ArrayRef<int64_t>{0});
33699ef9eebSMatthias Springer return success();
33799ef9eebSMatthias Springer }
33899ef9eebSMatthias Springer };
33999ef9eebSMatthias Springer
populateVectorMultiReductionLoweringPatterns(RewritePatternSet & patterns,VectorMultiReductionLowering options)34099ef9eebSMatthias Springer void mlir::vector::populateVectorMultiReductionLoweringPatterns(
34199ef9eebSMatthias Springer RewritePatternSet &patterns, VectorMultiReductionLowering options) {
34299ef9eebSMatthias Springer patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
34399ef9eebSMatthias Springer patterns.getContext(), options);
34499ef9eebSMatthias Springer patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
34599ef9eebSMatthias Springer if (options == VectorMultiReductionLowering ::InnerReduction)
34699ef9eebSMatthias Springer patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
34799ef9eebSMatthias Springer else
34899ef9eebSMatthias Springer patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());
34999ef9eebSMatthias Springer }
350