//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
233d2a780SThomas Raoux //
333d2a780SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
433d2a780SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
533d2a780SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
633d2a780SThomas Raoux //
733d2a780SThomas Raoux //===----------------------------------------------------------------------===//
833d2a780SThomas Raoux //
933d2a780SThomas Raoux // This file implements linalg transformation to break a reduction dimension
1033d2a780SThomas Raoux // between a parallel and a reduction dimension.
1133d2a780SThomas Raoux //
1233d2a780SThomas Raoux //===----------------------------------------------------------------------===//
1333d2a780SThomas Raoux
14e188ad8bSMehdi Amini #include <utility>
15e188ad8bSMehdi Amini
1633d2a780SThomas Raoux #include "mlir/Analysis/SliceAnalysis.h"
1733d2a780SThomas Raoux #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18*178f9bd6SNicolas Vasilache #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1933d2a780SThomas Raoux #include "mlir/Dialect/Linalg/IR/Linalg.h"
2033d2a780SThomas Raoux #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
2133d2a780SThomas Raoux #include "mlir/Dialect/Linalg/Utils/Utils.h"
2233d2a780SThomas Raoux #include "mlir/Dialect/Tensor/IR/Tensor.h"
23d5716395SNicolas Vasilache #include "mlir/Dialect/Tensor/Utils/Utils.h"
2433d2a780SThomas Raoux #include "mlir/IR/PatternMatch.h"
2533d2a780SThomas Raoux
2633d2a780SThomas Raoux using namespace mlir;
2733d2a780SThomas Raoux using namespace mlir::linalg;
2833d2a780SThomas Raoux
2933d2a780SThomas Raoux /// Return the identity numeric value associated to the give op.
getNeutralElement(Operation * op)30d5716395SNicolas Vasilache static Attribute getNeutralElement(Operation *op) {
3133d2a780SThomas Raoux // Builder only used as helper for attribute creation.
3233d2a780SThomas Raoux OpBuilder b(op->getContext());
3333d2a780SThomas Raoux Type resultType = op->getResult(0).getType();
3433d2a780SThomas Raoux if (auto floatType = resultType.dyn_cast<FloatType>()) {
3533d2a780SThomas Raoux const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
3633d2a780SThomas Raoux if (isa<arith::AddFOp>(op))
3733d2a780SThomas Raoux return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
3833d2a780SThomas Raoux if (isa<arith::MulFOp>(op))
3933d2a780SThomas Raoux return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
4033d2a780SThomas Raoux if (isa<arith::MaxFOp>(op))
4133d2a780SThomas Raoux return b.getFloatAttr(resultType,
4233d2a780SThomas Raoux llvm::APFloat::getLargest(semantic, true));
4333d2a780SThomas Raoux if (isa<arith::MinFOp>(op))
4433d2a780SThomas Raoux return b.getFloatAttr(resultType,
4533d2a780SThomas Raoux llvm::APFloat::getLargest(semantic, true));
46d5716395SNicolas Vasilache return Attribute();
4733d2a780SThomas Raoux }
4833d2a780SThomas Raoux if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
4933d2a780SThomas Raoux return b.getIntegerAttr(resultType, 0);
5033d2a780SThomas Raoux if (isa<arith::AndIOp>(op))
5133d2a780SThomas Raoux return b.getIntegerAttr(resultType, -1);
5233d2a780SThomas Raoux if (isa<arith::MaxSIOp>(op))
5333d2a780SThomas Raoux return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
5433d2a780SThomas Raoux if (isa<arith::MinSIOp>(op))
5533d2a780SThomas Raoux return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
5633d2a780SThomas Raoux if (isa<arith::MulIOp>(op))
5733d2a780SThomas Raoux return b.getIntegerAttr(resultType, 1);
58d5716395SNicolas Vasilache return Attribute();
5933d2a780SThomas Raoux }
6033d2a780SThomas Raoux
splitReduction(PatternRewriter & b,LinalgOp op,const ControlSplitReductionFn & controlSplitReductionFn,const LinalgTransformationFilter & filter,bool useAlloc)61e188ad8bSMehdi Amini FailureOr<LinalgOp> mlir::linalg::splitReduction(
62e188ad8bSMehdi Amini PatternRewriter &b, LinalgOp op,
63e188ad8bSMehdi Amini const ControlSplitReductionFn &controlSplitReductionFn,
64*178f9bd6SNicolas Vasilache const LinalgTransformationFilter &filter, bool useAlloc) {
6533d2a780SThomas Raoux if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
6633d2a780SThomas Raoux op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
6733d2a780SThomas Raoux !op.hasOnlyProjectedPermutations())
6833d2a780SThomas Raoux return b.notifyMatchFailure(op, "precondition not met");
69f439b319SNicolas Vasilache
70f439b319SNicolas Vasilache FailureOr<SplitReductionResult> res =
71*178f9bd6SNicolas Vasilache splitReduction(b, op, controlSplitReductionFn, useAlloc);
72f439b319SNicolas Vasilache if (failed(res))
73f439b319SNicolas Vasilache return failure();
74f439b319SNicolas Vasilache
75f439b319SNicolas Vasilache filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
76f439b319SNicolas Vasilache filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
77f439b319SNicolas Vasilache
78f439b319SNicolas Vasilache return res->splitLinalgOp;
79f439b319SNicolas Vasilache }
80f439b319SNicolas Vasilache
/// Core implementation of reduction splitting via shape expansion: rewrite a
/// single-reduction LinalgOp into (1) a GenericOp with an extra parallel
/// dimension computing partial reductions into an intermediate tensor
/// pre-filled with the combiner's neutral element, followed by (2) a GenericOp
/// that reduces the extra dimension into the original output. Both new ops and
/// the intermediate allocation are returned in a SplitReductionResult.
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // `control` packs the split ratio and the loop position at which the new
  // parallel dimension is inserted.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  // A single reduction loop is a precondition (enforced by the filtered
  // overload); anything else is a programming error here.
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  // The reduction extent must be static, evenly divisible by `ratio`, and the
  // requested insertion position must be a valid loop index.
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");

  // Find the single combiner op (e.g. arith.addf) feeding the yield.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  // The neutral element is needed to pre-fill the intermediate output so the
  // partial reductions start from the identity value.
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduced dimension splits in two: (ratio, size / ratio).
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        // Braced-init-list guarantees left-to-right evaluation of `index++`.
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      // Non-reduced dims keep their extent; their loop index shifts by one
      // when it falls after the newly inserted split dimension.
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    // Otherwise reshape the input to expose the split dimension explicitly.
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      // The new parallel dimension of extent `ratio`.
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  // Materialize the intermediate output: a bufferization.alloc_tensor when
  // `useAlloc` is set (bufferizes to a fresh allocation), otherwise a
  // linalg.init_tensor.
  Value initOrAllocTensor;
  if (useAlloc) {
    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    initOrAllocTensor = b.create<linalg::InitTensorOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  // Pre-fill the intermediate tensor with the combiner's neutral element.
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  // Iterator types: the inserted split dimension iterates in parallel; all
  // original iterator types (including the now-partial reduction) remain.
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension. The original body is reused as-is by moving the region.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduce the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      // Only the split dimension is reduced; it is dropped from the output map.
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  // The combining op's body is a clone of the original combiner applied to
  // (partial result, accumulator).
  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{
      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
      cast<LinalgOp>(genericOp.getOperation()), reduction};
}
24033d2a780SThomas Raoux
241d5716395SNicolas Vasilache /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
242d5716395SNicolas Vasilache /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
243d5716395SNicolas Vasilache /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
244d5716395SNicolas Vasilache /// done as a transform to enable better vectorization.
scaleReductionDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t reductionRatio)245d5716395SNicolas Vasilache static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
246d5716395SNicolas Vasilache unsigned reductionDimPos,
247d5716395SNicolas Vasilache int64_t reductionRatio) {
248d5716395SNicolas Vasilache auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
249d5716395SNicolas Vasilache auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
250d5716395SNicolas Vasilache AffineMap map = op.getTiedIndexingMap(&opOperand);
251d5716395SNicolas Vasilache AffineMap idMap =
252d5716395SNicolas Vasilache AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
253d5716395SNicolas Vasilache AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
254d5716395SNicolas Vasilache AffineMap composeMap = shiftedIdMap.replace(
255d5716395SNicolas Vasilache reductionDim, reductionDim * reductionRatio + reductionDimP1,
256d5716395SNicolas Vasilache shiftedIdMap.getNumDims(), /*numSymbols=*/0);
257d5716395SNicolas Vasilache return map.compose(composeMap);
258d5716395SNicolas Vasilache }
259d5716395SNicolas Vasilache
insertParallelDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t size)260d5716395SNicolas Vasilache static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
261d5716395SNicolas Vasilache unsigned reductionDimPos, int64_t size) {
262d5716395SNicolas Vasilache auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
263d5716395SNicolas Vasilache AffineMap map = op.getTiedIndexingMap(&opOperand);
264d5716395SNicolas Vasilache AffineMap idMap =
265d5716395SNicolas Vasilache AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
266d5716395SNicolas Vasilache AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
267d5716395SNicolas Vasilache return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
268d5716395SNicolas Vasilache }
269d5716395SNicolas Vasilache
/// Core rewrite implementation of reduction splitting by index scaling:
/// instead of reshaping inputs, the reduction index `k` is re-indexed as
/// `k * splitFactor + kk` and a new parallel dimension `kk` is introduced,
/// producing a partial-reduction op followed by one combining op per output.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part, enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  // Only the first reduction dimension is split.
  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  // Require a static extent evenly divisible by the split factor, and a valid
  // insertion position.
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension not divisible by split factor");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  // One neutral element per combiner; all must be recognized.
  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<Operation *> initOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    // The intermediate type has an extra dim of extent size/splitFactor
    // inserted at `insertSplitDimension`.
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initOrAllocTensor;
    if (useAlloc) {
      initOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      initOrAllocTensor = b.create<linalg::InitTensorOp>(
          loc, dims, newT.getShape(), t.getElementType());
    }
    // Pre-fill each intermediate with its combiner's neutral element.
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  // Add the block argument corresponding to the shape-only tensor; its value
  // is never used inside the body.
  genericOp.region().front().insertArgument(reductionDimPos,
                                            b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    // Identity on the intermediate; drop the split dim for the final output.
    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}
448d5716395SNicolas Vasilache
namespace {

/// Rewrite-pattern wrapper around the filtered `splitReduction` entry point,
/// so the transformation can be driven by the pattern rewriter
/// infrastructure.
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  /// `useAlloc` selects bufferization.alloc_tensor over linalg.init_tensor for
  /// the intermediate output.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, bool useAlloc = false,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc), filter(std::move(f)) {}

  // Delegates all matching and rewriting to the free function; its
  // FailureOr<LinalgOp> result converts to the LogicalResult expected here.
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
                          useAlloc);
  }

private:
  // User callback returning (split ratio, insertion position) per op.
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
  LinalgTransformationFilter filter;
};

} // namespace
47433d2a780SThomas Raoux
populateSplitReductionPattern(RewritePatternSet & patterns,const ControlSplitReductionFn & controlSplitReductionFn,const LinalgTransformationFilter & f,bool useAlloc)47533d2a780SThomas Raoux void linalg::populateSplitReductionPattern(
47633d2a780SThomas Raoux RewritePatternSet &patterns,
477e188ad8bSMehdi Amini const ControlSplitReductionFn &controlSplitReductionFn,
478*178f9bd6SNicolas Vasilache const LinalgTransformationFilter &f, bool useAlloc) {
47933d2a780SThomas Raoux patterns.add<LinalgSplitReduction>(patterns.getContext(),
480*178f9bd6SNicolas Vasilache controlSplitReductionFn, f, useAlloc);
48133d2a780SThomas Raoux }
482