//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
233d2a780SThomas Raoux //
333d2a780SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
433d2a780SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
533d2a780SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
633d2a780SThomas Raoux //
733d2a780SThomas Raoux //===----------------------------------------------------------------------===//
833d2a780SThomas Raoux //
933d2a780SThomas Raoux // This file implements linalg transformation to break a reduction dimension
1033d2a780SThomas Raoux // between a parallel and a reduction dimension.
1133d2a780SThomas Raoux //
1233d2a780SThomas Raoux //===----------------------------------------------------------------------===//
1333d2a780SThomas Raoux 
#include <limits>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
2533d2a780SThomas Raoux 
2633d2a780SThomas Raoux using namespace mlir;
2733d2a780SThomas Raoux using namespace mlir::linalg;
2833d2a780SThomas Raoux 
2933d2a780SThomas Raoux /// Return the identity numeric value associated to the give op.
getNeutralElement(Operation * op)30d5716395SNicolas Vasilache static Attribute getNeutralElement(Operation *op) {
3133d2a780SThomas Raoux   // Builder only used as helper for attribute creation.
3233d2a780SThomas Raoux   OpBuilder b(op->getContext());
3333d2a780SThomas Raoux   Type resultType = op->getResult(0).getType();
3433d2a780SThomas Raoux   if (auto floatType = resultType.dyn_cast<FloatType>()) {
3533d2a780SThomas Raoux     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
3633d2a780SThomas Raoux     if (isa<arith::AddFOp>(op))
3733d2a780SThomas Raoux       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
3833d2a780SThomas Raoux     if (isa<arith::MulFOp>(op))
3933d2a780SThomas Raoux       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
4033d2a780SThomas Raoux     if (isa<arith::MaxFOp>(op))
4133d2a780SThomas Raoux       return b.getFloatAttr(resultType,
4233d2a780SThomas Raoux                             llvm::APFloat::getLargest(semantic, true));
4333d2a780SThomas Raoux     if (isa<arith::MinFOp>(op))
4433d2a780SThomas Raoux       return b.getFloatAttr(resultType,
4533d2a780SThomas Raoux                             llvm::APFloat::getLargest(semantic, true));
46d5716395SNicolas Vasilache     return Attribute();
4733d2a780SThomas Raoux   }
4833d2a780SThomas Raoux   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
4933d2a780SThomas Raoux     return b.getIntegerAttr(resultType, 0);
5033d2a780SThomas Raoux   if (isa<arith::AndIOp>(op))
5133d2a780SThomas Raoux     return b.getIntegerAttr(resultType, -1);
5233d2a780SThomas Raoux   if (isa<arith::MaxSIOp>(op))
5333d2a780SThomas Raoux     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
5433d2a780SThomas Raoux   if (isa<arith::MinSIOp>(op))
5533d2a780SThomas Raoux     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
5633d2a780SThomas Raoux   if (isa<arith::MulIOp>(op))
5733d2a780SThomas Raoux     return b.getIntegerAttr(resultType, 1);
58d5716395SNicolas Vasilache   return Attribute();
5933d2a780SThomas Raoux }
6033d2a780SThomas Raoux 
splitReduction(PatternRewriter & b,LinalgOp op,const ControlSplitReductionFn & controlSplitReductionFn,const LinalgTransformationFilter & filter,bool useAlloc)61e188ad8bSMehdi Amini FailureOr<LinalgOp> mlir::linalg::splitReduction(
62e188ad8bSMehdi Amini     PatternRewriter &b, LinalgOp op,
63e188ad8bSMehdi Amini     const ControlSplitReductionFn &controlSplitReductionFn,
64*178f9bd6SNicolas Vasilache     const LinalgTransformationFilter &filter, bool useAlloc) {
6533d2a780SThomas Raoux   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
6633d2a780SThomas Raoux       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
6733d2a780SThomas Raoux       !op.hasOnlyProjectedPermutations())
6833d2a780SThomas Raoux     return b.notifyMatchFailure(op, "precondition not met");
69f439b319SNicolas Vasilache 
70f439b319SNicolas Vasilache   FailureOr<SplitReductionResult> res =
71*178f9bd6SNicolas Vasilache       splitReduction(b, op, controlSplitReductionFn, useAlloc);
72f439b319SNicolas Vasilache   if (failed(res))
73f439b319SNicolas Vasilache     return failure();
74f439b319SNicolas Vasilache 
75f439b319SNicolas Vasilache   filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
76f439b319SNicolas Vasilache   filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
77f439b319SNicolas Vasilache 
78f439b319SNicolas Vasilache   return res->splitLinalgOp;
79f439b319SNicolas Vasilache }
80f439b319SNicolas Vasilache 
/// Core implementation: split the single reduction dimension of `op` into an
/// outer parallel dimension of size `ratio` and an inner reduction dimension
/// of size `reductionDimSize / ratio`, as directed by
/// `controlSplitReductionFn` (which returns the split ratio and the loop
/// position at which the new parallel dimension is inserted).
/// On success the original `op` is replaced and the created init/alloc
/// tensor, fill, expanded generic and final combining reduction are returned
/// in a SplitReductionResult.
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Query the control function for the split ratio and the position of the
  // new parallel dimension.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  // Callers guarantee a single reduction loop (see the filtered overload).
  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  // Only static reduction extents that divide evenly by `ratio` are
  // supported, and the insertion position must be a valid loop index.
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");

  // The combiner must be a single supported binary reduction operation.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  // Inputs that index the reduction dimension get that dimension expanded
  // (via tensor.expand_shape) into `ratio x (size / ratio)`.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduction dimension expands into two: the new parallel dim
        // followed by the shrunken reduction dim.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      // Non-reduction dims keep their extent; their loop index shifts by one
      // when it comes after the inserted dimension.
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  // Materialize the intermediate output tensor, either as an alloc_tensor
  // (bufferization-controlled) or an init_tensor, then fill it with the
  // reduction's neutral element.
  Value initOrAllocTensor;
  if (useAlloc) {
    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    initOrAllocTensor = b.create<linalg::InitTensorOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  // Iterator types: the inserted dimension is parallel, all original
  // iterators are preserved (the old reduction dim keeps reducing).
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduce the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  // The combiner body is a clone of the matched reduction operation applied
  // to (intermediate element, accumulator).
  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{
      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
      cast<LinalgOp>(genericOp.getOperation()), reduction};
}
24033d2a780SThomas Raoux 
241d5716395SNicolas Vasilache /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
242d5716395SNicolas Vasilache /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
243d5716395SNicolas Vasilache /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
244d5716395SNicolas Vasilache /// done as a transform to enable better vectorization.
scaleReductionDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t reductionRatio)245d5716395SNicolas Vasilache static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
246d5716395SNicolas Vasilache                                    unsigned reductionDimPos,
247d5716395SNicolas Vasilache                                    int64_t reductionRatio) {
248d5716395SNicolas Vasilache   auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
249d5716395SNicolas Vasilache   auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
250d5716395SNicolas Vasilache   AffineMap map = op.getTiedIndexingMap(&opOperand);
251d5716395SNicolas Vasilache   AffineMap idMap =
252d5716395SNicolas Vasilache       AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
253d5716395SNicolas Vasilache   AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
254d5716395SNicolas Vasilache   AffineMap composeMap = shiftedIdMap.replace(
255d5716395SNicolas Vasilache       reductionDim, reductionDim * reductionRatio + reductionDimP1,
256d5716395SNicolas Vasilache       shiftedIdMap.getNumDims(), /*numSymbols=*/0);
257d5716395SNicolas Vasilache   return map.compose(composeMap);
258d5716395SNicolas Vasilache }
259d5716395SNicolas Vasilache 
insertParallelDim(LinalgOp op,OpOperand & opOperand,unsigned reductionDimPos,int64_t size)260d5716395SNicolas Vasilache static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
261d5716395SNicolas Vasilache                                    unsigned reductionDimPos, int64_t size) {
262d5716395SNicolas Vasilache   auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
263d5716395SNicolas Vasilache   AffineMap map = op.getTiedIndexingMap(&opOperand);
264d5716395SNicolas Vasilache   AffineMap idMap =
265d5716395SNicolas Vasilache       AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
266d5716395SNicolas Vasilache   AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
267d5716395SNicolas Vasilache   return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
268d5716395SNicolas Vasilache }
269d5716395SNicolas Vasilache 
270d5716395SNicolas Vasilache /// Core rewrite implementation.
splitReductionByScaling(PatternRewriter & b,LinalgOp op,const ControlSplitReductionFn & controlSplitReductionFn,bool useAlloc)271d5716395SNicolas Vasilache FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
272d5716395SNicolas Vasilache     PatternRewriter &b, LinalgOp op,
273*178f9bd6SNicolas Vasilache     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
274d5716395SNicolas Vasilache   OpBuilder::InsertionGuard guard(b);
275d5716395SNicolas Vasilache   b.setInsertionPoint(op);
276d5716395SNicolas Vasilache 
277d5716395SNicolas Vasilache   // Matcher part, enforce preconditions.
278d5716395SNicolas Vasilache   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
279d5716395SNicolas Vasilache   int64_t splitFactor = control.first;
280d5716395SNicolas Vasilache   unsigned insertSplitDimension = control.second;
281d5716395SNicolas Vasilache   if (splitFactor <= 1)
282d5716395SNicolas Vasilache     return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
283d5716395SNicolas Vasilache 
284d5716395SNicolas Vasilache   SmallVector<unsigned> dims;
285d5716395SNicolas Vasilache   op.getReductionDims(dims);
286d5716395SNicolas Vasilache   if (dims.empty())
287d5716395SNicolas Vasilache     return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
288d5716395SNicolas Vasilache 
289d5716395SNicolas Vasilache   unsigned reductionDimPos = dims[0];
290d5716395SNicolas Vasilache   SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
291d5716395SNicolas Vasilache   int64_t reductionDimSize = loopRanges[reductionDimPos];
292d5716395SNicolas Vasilache   if (reductionDimSize == ShapedType::kDynamicSize ||
293d5716395SNicolas Vasilache       reductionDimSize % splitFactor != 0 ||
294d5716395SNicolas Vasilache       insertSplitDimension >= loopRanges.size())
295d5716395SNicolas Vasilache     return b.notifyMatchFailure(
296d5716395SNicolas Vasilache         op, "first reduction dimension not divisible by split factor");
297d5716395SNicolas Vasilache 
298d5716395SNicolas Vasilache   SmallVector<Operation *> combinerOps;
299d5716395SNicolas Vasilache   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
300d5716395SNicolas Vasilache     return b.notifyMatchFailure(op, "cannot match a reduction pattern");
301d5716395SNicolas Vasilache 
302d5716395SNicolas Vasilache   SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
303d5716395SNicolas Vasilache       llvm::map_range(combinerOps, [&](Operation *reductionOp) {
304d5716395SNicolas Vasilache         return getNeutralElement(reductionOp);
305d5716395SNicolas Vasilache       }));
306d5716395SNicolas Vasilache   if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
307d5716395SNicolas Vasilache     return b.notifyMatchFailure(op, "unknown reduction neutral");
308d5716395SNicolas Vasilache 
309d5716395SNicolas Vasilache   // TODO: relax this when multi-reduction support is available.
310*178f9bd6SNicolas Vasilache   if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
311d5716395SNicolas Vasilache     return b.notifyMatchFailure(op, "expect one reduction per output");
312d5716395SNicolas Vasilache 
313d5716395SNicolas Vasilache   // Rewrite part.
314d5716395SNicolas Vasilache   // Step 1. Build the intermediate outputs filled with the proper
315d5716395SNicolas Vasilache   // neutralElements. Such outputs are of the same shape with an extra dimension
316d5716395SNicolas Vasilache   // inserted at `insertSplitDimension`.
317d5716395SNicolas Vasilache   //
318d5716395SNicolas Vasilache   // Consider a minimal example where `k` is reduced:
319d5716395SNicolas Vasilache   //     O(i, j) += I(i, j, k)
320d5716395SNicolas Vasilache   // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
321d5716395SNicolas Vasilache   // The compute is rewritten as:
322d5716395SNicolas Vasilache   //   a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
323d5716395SNicolas Vasilache   //   b. O(i, j) += O_i(kk, i, j)
324d5716395SNicolas Vasilache   // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
325d5716395SNicolas Vasilache   Location loc = op->getLoc();
326d5716395SNicolas Vasilache   MLIRContext *context = op.getContext();
327d5716395SNicolas Vasilache   // For now assume outputs are 1-1 with reduction neutralElements.
328d5716395SNicolas Vasilache   // TODO: generalize when multi-reduction support is available.
329d5716395SNicolas Vasilache   SmallVector<Value> newOutputs;
330d5716395SNicolas Vasilache   newOutputs.reserve(op.getNumOutputs());
331*178f9bd6SNicolas Vasilache   SmallVector<Operation *> initOrAllocTensorOps;
332d5716395SNicolas Vasilache   SmallVector<linalg::FillOp> fillOps;
333d5716395SNicolas Vasilache   fillOps.reserve(op.getNumOutputs());
334d5716395SNicolas Vasilache   for (auto it : llvm::zip(op.outputs(), neutralElements)) {
335d5716395SNicolas Vasilache     Value rankedTensor = std::get<0>(it);
336d5716395SNicolas Vasilache     auto t = rankedTensor.getType().cast<RankedTensorType>();
337d5716395SNicolas Vasilache     RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
338d5716395SNicolas Vasilache         reductionDimSize / splitFactor, insertSplitDimension);
339d5716395SNicolas Vasilache     SmallVector<Value> dims =
340d5716395SNicolas Vasilache         tensor::createDynamicDimValues(b, loc, rankedTensor);
341*178f9bd6SNicolas Vasilache     Value initOrAllocTensor;
342*178f9bd6SNicolas Vasilache     if (useAlloc) {
343*178f9bd6SNicolas Vasilache       initOrAllocTensor =
344*178f9bd6SNicolas Vasilache           b.create<bufferization::AllocTensorOp>(loc, newT, dims);
345*178f9bd6SNicolas Vasilache     } else {
346*178f9bd6SNicolas Vasilache       initOrAllocTensor = b.create<linalg::InitTensorOp>(
347d5716395SNicolas Vasilache           loc, dims, newT.getShape(), t.getElementType());
348*178f9bd6SNicolas Vasilache     }
349d5716395SNicolas Vasilache     Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
350d5716395SNicolas Vasilache     fillOps.push_back(
351*178f9bd6SNicolas Vasilache         b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
352d5716395SNicolas Vasilache     newOutputs.push_back(fillOps.back().getResult(0));
353*178f9bd6SNicolas Vasilache     initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
354d5716395SNicolas Vasilache   }
355d5716395SNicolas Vasilache 
356d5716395SNicolas Vasilache   // Step 2. Reindex / expand indexing maps.
357d5716395SNicolas Vasilache   // Reindex existing input indexings: k -> k * splitFactor + k'.
358d5716395SNicolas Vasilache   SmallVector<AffineMap> newMaps;
359d5716395SNicolas Vasilache   newMaps.reserve(op.getNumInputsAndOutputs() + 1);
360d5716395SNicolas Vasilache   for (OpOperand *o : op.getInputOperands())
361d5716395SNicolas Vasilache     newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
362d5716395SNicolas Vasilache   // Provision a new indexing for the shape-only tensor.
363d5716395SNicolas Vasilache   auto nDims = op.getNumLoops() + 1;
364d5716395SNicolas Vasilache   auto redDim = getAffineDimExpr(reductionDimPos, context);
365d5716395SNicolas Vasilache   auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
366d5716395SNicolas Vasilache   newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
367d5716395SNicolas Vasilache   // Expand existing output indexings.
368d5716395SNicolas Vasilache   // TODO: a subset of these may not reduce along reducePos and should be
369d5716395SNicolas Vasilache   // reindexed: k -> k * splitFactor + k', when multi-reduction support is
370d5716395SNicolas Vasilache   // available.
371d5716395SNicolas Vasilache   for (OpOperand *o : op.getOutputOperands())
372d5716395SNicolas Vasilache     newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
373d5716395SNicolas Vasilache                                         reductionDimSize / splitFactor));
374d5716395SNicolas Vasilache 
375d5716395SNicolas Vasilache   // Step 3. Handle operands.
376d5716395SNicolas Vasilache   // Compute the new input tensors.
377d5716395SNicolas Vasilache   auto newInputs = llvm::to_vector<4>(op.inputs());
378d5716395SNicolas Vasilache   // Add a single shape-only tensor to carry the dimensions without resorting to
379d5716395SNicolas Vasilache   // more complex inversions.
380d5716395SNicolas Vasilache   newInputs.push_back(b.create<linalg::InitTensorOp>(
381d5716395SNicolas Vasilache       loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
382d5716395SNicolas Vasilache       b.getIntegerType(1)));
383d5716395SNicolas Vasilache   // Output tensors are already good to go.
384d5716395SNicolas Vasilache 
385d5716395SNicolas Vasilache   // Step 4. Create the new op matching the original op with an extra parallel
386d5716395SNicolas Vasilache   // dimension.
387d5716395SNicolas Vasilache   SmallVector<StringRef> iteratorTypes =
388d5716395SNicolas Vasilache       llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
389d5716395SNicolas Vasilache   iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
390d5716395SNicolas Vasilache                        getParallelIteratorTypeName());
391d5716395SNicolas Vasilache   GenericOp genericOp =
392d5716395SNicolas Vasilache       b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
393d5716395SNicolas Vasilache                           newOutputs, newMaps, iteratorTypes);
394d5716395SNicolas Vasilache   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
395d5716395SNicolas Vasilache                        genericOp.region().begin());
396d5716395SNicolas Vasilache   genericOp.region().front().insertArgument(reductionDimPos,
397d5716395SNicolas Vasilache                                             b.getIntegerType(1), loc);
398d5716395SNicolas Vasilache 
399d5716395SNicolas Vasilache   // Step 5. Create new reduction ops that only reduce the newly added
400d5716395SNicolas Vasilache   // dimensions from the previous op.
401d5716395SNicolas Vasilache   // For now assume outputs are 1-1 with reduction ops.
402d5716395SNicolas Vasilache   // TODO: a subset of these may not reduce in the first place and do not
403d5716395SNicolas Vasilache   // require a new op, when multi-reduction support is available.
404d5716395SNicolas Vasilache   // TODO: all results can be handled in a single GenericOp, when
405d5716395SNicolas Vasilache   // multi-reduction support is available.
406d5716395SNicolas Vasilache   SmallVector<LinalgOp> results;
407d5716395SNicolas Vasilache   for (auto it :
408d5716395SNicolas Vasilache        llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
409d5716395SNicolas Vasilache     Value reindexedOutput = std::get<0>(it);
410d5716395SNicolas Vasilache     Value originalOutput = std::get<1>(it);
411d5716395SNicolas Vasilache     auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
412d5716395SNicolas Vasilache     Operation *combinerOp = std::get<2>(it);
413d5716395SNicolas Vasilache 
414d5716395SNicolas Vasilache     AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
415d5716395SNicolas Vasilache     SmallVector<AffineMap> indexingMaps = {
416d5716395SNicolas Vasilache         map, map.dropResult(insertSplitDimension)};
417d5716395SNicolas Vasilache     SmallVector<StringRef> reductionIteratorTypes(
418d5716395SNicolas Vasilache         originalOutputType.getRank() + 1, getParallelIteratorTypeName());
419d5716395SNicolas Vasilache     reductionIteratorTypes[insertSplitDimension] =
420d5716395SNicolas Vasilache         getReductionIteratorTypeName();
421d5716395SNicolas Vasilache 
422d5716395SNicolas Vasilache     // clang-format off
423d5716395SNicolas Vasilache     auto reductionOp = b.create<GenericOp>(
424d5716395SNicolas Vasilache         loc,
425d5716395SNicolas Vasilache         originalOutputType,
426d5716395SNicolas Vasilache         reindexedOutput,
427d5716395SNicolas Vasilache         originalOutput,
428d5716395SNicolas Vasilache         indexingMaps,
429d5716395SNicolas Vasilache         reductionIteratorTypes,
430d5716395SNicolas Vasilache         [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
431d5716395SNicolas Vasilache           Operation *clonedReductionOp = b.clone(*combinerOp);
432d5716395SNicolas Vasilache           clonedReductionOp->setOperand(0, bbArgs[0]);
433d5716395SNicolas Vasilache           clonedReductionOp->setOperand(1, bbArgs[1]);
434d5716395SNicolas Vasilache           b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
435d5716395SNicolas Vasilache         });
436d5716395SNicolas Vasilache     // clang-format on
437d5716395SNicolas Vasilache 
438d5716395SNicolas Vasilache     results.push_back(reductionOp);
439d5716395SNicolas Vasilache   }
440d5716395SNicolas Vasilache 
441d5716395SNicolas Vasilache   // TODO: extend when multi-reduction support is available.
442d5716395SNicolas Vasilache   assert(fillOps.size() == results.size() && results.size() == 1);
443d5716395SNicolas Vasilache   b.replaceOp(op, results.front()->getResults());
444*178f9bd6SNicolas Vasilache   return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
445d5716395SNicolas Vasilache                               cast<LinalgOp>(genericOp.getOperation()),
446d5716395SNicolas Vasilache                               results.front()};
447d5716395SNicolas Vasilache }
448d5716395SNicolas Vasilache 
44933d2a780SThomas Raoux namespace {
45033d2a780SThomas Raoux 
45133d2a780SThomas Raoux struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
45233d2a780SThomas Raoux   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgSplitReduction__anon86cbc2bb0511::LinalgSplitReduction45333d2a780SThomas Raoux   LinalgSplitReduction(MLIRContext *context,
45433d2a780SThomas Raoux                        ControlSplitReductionFn controlSplitReductionFn,
455*178f9bd6SNicolas Vasilache                        LinalgTransformationFilter f, bool useAlloc = false,
456*178f9bd6SNicolas Vasilache                        PatternBenefit benefit = 1)
45733d2a780SThomas Raoux       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
458e188ad8bSMehdi Amini         controlSplitReductionFn(std::move(controlSplitReductionFn)),
459*178f9bd6SNicolas Vasilache         useAlloc(useAlloc), filter(std::move(f)) {}
46033d2a780SThomas Raoux 
matchAndRewrite__anon86cbc2bb0511::LinalgSplitReduction46133d2a780SThomas Raoux   LogicalResult matchAndRewrite(LinalgOp op,
46233d2a780SThomas Raoux                                 PatternRewriter &rewriter) const override {
463*178f9bd6SNicolas Vasilache     return splitReduction(rewriter, op, controlSplitReductionFn, filter,
464*178f9bd6SNicolas Vasilache                           useAlloc);
46533d2a780SThomas Raoux   }
46633d2a780SThomas Raoux 
46733d2a780SThomas Raoux private:
46833d2a780SThomas Raoux   ControlSplitReductionFn controlSplitReductionFn;
469*178f9bd6SNicolas Vasilache   bool useAlloc;
47033d2a780SThomas Raoux   LinalgTransformationFilter filter;
47133d2a780SThomas Raoux };
47233d2a780SThomas Raoux 
47333d2a780SThomas Raoux } // namespace
47433d2a780SThomas Raoux 
populateSplitReductionPattern(RewritePatternSet & patterns,const ControlSplitReductionFn & controlSplitReductionFn,const LinalgTransformationFilter & f,bool useAlloc)47533d2a780SThomas Raoux void linalg::populateSplitReductionPattern(
47633d2a780SThomas Raoux     RewritePatternSet &patterns,
477e188ad8bSMehdi Amini     const ControlSplitReductionFn &controlSplitReductionFn,
478*178f9bd6SNicolas Vasilache     const LinalgTransformationFilter &f, bool useAlloc) {
47933d2a780SThomas Raoux   patterns.add<LinalgSplitReduction>(patterns.getContext(),
480*178f9bd6SNicolas Vasilache                                      controlSplitReductionFn, f, useAlloc);
48133d2a780SThomas Raoux }
482