//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg transformation to break a reduction dimension
10 // between a parallel and a reduction dimension.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include <limits>
#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 
27 /// Return the identity numeric value associated to the give op.
28 static Optional<Attribute> getIdentity(Operation *op) {
29   // Builder only used as helper for attribute creation.
30   OpBuilder b(op->getContext());
31   Type resultType = op->getResult(0).getType();
32   if (auto floatType = resultType.dyn_cast<FloatType>()) {
33     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
34     if (isa<arith::AddFOp>(op))
35       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
36     if (isa<arith::MulFOp>(op))
37       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
38     if (isa<arith::MaxFOp>(op))
39       return b.getFloatAttr(resultType,
40                             llvm::APFloat::getLargest(semantic, true));
41     if (isa<arith::MinFOp>(op))
42       return b.getFloatAttr(resultType,
43                             llvm::APFloat::getLargest(semantic, true));
44     return llvm::None;
45   }
46   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
47     return b.getIntegerAttr(resultType, 0);
48   if (isa<arith::AndIOp>(op))
49     return b.getIntegerAttr(resultType, -1);
50   if (isa<arith::MaxSIOp>(op))
51     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
52   if (isa<arith::MinSIOp>(op))
53     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
54   if (isa<arith::MulIOp>(op))
55     return b.getIntegerAttr(resultType, 1);
56   return llvm::None;
57 }
58 
59 FailureOr<LinalgOp> mlir::linalg::splitReduction(
60     PatternRewriter &b, LinalgOp op,
61     const ControlSplitReductionFn &controlSplitReductionFn,
62     const LinalgTransformationFilter &filter) {
63   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
64       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
65       !op.hasOnlyProjectedPermutations())
66     return b.notifyMatchFailure(op, "precondition not met");
67   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
68   int64_t ratio = control.first;
69   unsigned insertDimIndex = control.second;
70   if (ratio <= 1)
71     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
72   SmallVector<unsigned> dims;
73   op.getReductionDims(dims);
74   assert(dims.size() == 1);
75   unsigned reductionDim = dims[0];
76   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
77   int64_t reductionDimSize = loopRanges[reductionDim];
78   if (reductionDimSize == ShapedType::kDynamicSize ||
79       reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
80     return b.notifyMatchFailure(
81         op, "Reduction dimension not divisible by split ratio");
82   SmallVector<Operation *, 4> combinerOps;
83   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
84       combinerOps.size() != 1)
85     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
86   Operation *reductionOp = combinerOps[0];
87   Optional<Attribute> identity = getIdentity(reductionOp);
88   if (!identity)
89     return b.notifyMatchFailure(op, "Unknown identity value for the redution");
90 
91   Location loc = op->getLoc();
92   SmallVector<Value> newInputs;
93   SmallVector<AffineMap> newMaps;
94   // Calculate the new shapes and indexing maps of the input operands.
95   for (OpOperand *operand : op.getInputOperands()) {
96     AffineMap map = op.getTiedIndexingMap(operand);
97     SmallVector<int64_t> newShape;
98     SmallVector<AffineExpr> exprs;
99     SmallVector<ReassociationIndices> reassociation;
100     unsigned index = 0;
101     for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
102       unsigned dim = map.getDimPosition(idx);
103       if (reductionDim == dim) {
104         newShape.push_back(ratio);
105         newShape.push_back(op.getShape(operand)[idx] / ratio);
106         reassociation.push_back({index++, index++});
107         exprs.push_back(b.getAffineDimExpr(insertDimIndex));
108         exprs.push_back(
109             b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
110         continue;
111       }
112       newShape.push_back(op.getShape(operand)[idx]);
113       exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
114       reassociation.push_back({index++});
115     }
116     newMaps.push_back(
117         AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
118     // If the shape is unchanged the input doesn't change.
119     if (newShape == op.getShape(operand)) {
120       newInputs.push_back(operand->get());
121       continue;
122     }
123     Type newType = RankedTensorType::get(
124         newShape,
125         operand->get().getType().cast<RankedTensorType>().getElementType());
126     Value newInput = b.create<tensor::ExpandShapeOp>(
127         loc, newType, operand->get(), reassociation);
128     newInputs.push_back(newInput);
129   }
130   // Calculate the new output map and shape, we insert the new dimension based
131   // on the index returned by `controlSplitReductionFn`.
132   SmallVector<int64_t> newOutputShape;
133   AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
134   ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
135   SmallVector<AffineExpr> outputExpr;
136   for (unsigned idx :
137        llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
138     if (idx == insertDimIndex) {
139       newOutputShape.push_back(ratio);
140       outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
141       continue;
142     }
143     unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
144     newOutputShape.push_back(oldShape[oldDim]);
145     unsigned dim = oldOutputMap.getDimPosition(oldDim);
146     outputExpr.push_back(
147         b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
148   }
149   Value initTensor = b.create<linalg::InitTensorOp>(
150       loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
151   Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
152   Value identityTensor =
153       b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
154           .getResult(0);
155 
156   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
157                                    op.getContext()));
158   SmallVector<StringRef> newIteratorTypes;
159   for (auto &it : llvm::enumerate(op.iterator_types())) {
160     if (insertDimIndex == it.index())
161       newIteratorTypes.push_back(getParallelIteratorTypeName());
162     newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
163   }
164   // Create the new op matching the original op with an extra parallel
165   // dimension.
166   GenericOp genericOp = b.create<GenericOp>(
167       loc, TypeRange({initTensor.getType()}), newInputs,
168       ValueRange({identityTensor}), newMaps, newIteratorTypes);
169   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
170                        genericOp.region().begin());
171 
172   // Then create a new reduction that only reduce the newly added dimension from
173   // the previous op.
174   unsigned intermRank = newOutputShape.size();
175   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
176   SmallVector<Value> outputOperands = op.getOutputOperands();
177   SmallVector<StringRef> reductionIteratorTypes;
178   SmallVector<AffineExpr> exprs;
179   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
180     if (insertDimIndex == i) {
181       reductionIteratorTypes.push_back(getReductionIteratorTypeName());
182     } else {
183       exprs.push_back(b.getAffineDimExpr(i));
184       reductionIteratorTypes.push_back(getParallelIteratorTypeName());
185     }
186   }
187   AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
188   SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
189 
190   auto reduction = b.create<GenericOp>(
191       loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
192       outputOperands, reductionMaps, reductionIteratorTypes,
193       [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
194         Operation *clonedReductionOp = b.clone(*reductionOp);
195         clonedReductionOp->setOperand(0, inputs[0]);
196         clonedReductionOp->setOperand(1, inputs[1]);
197         b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
198       });
199   b.replaceOp(op, reduction.getResults());
200   filter.replaceLinalgTransformationFilter(b, genericOp);
201   filter.replaceLinalgTransformationFilter(b, reduction);
202   return cast<LinalgOp>(genericOp.getOperation());
203 }
204 
205 namespace {
206 
207 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
208   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
209   LinalgSplitReduction(MLIRContext *context,
210                        ControlSplitReductionFn controlSplitReductionFn,
211                        LinalgTransformationFilter f, PatternBenefit benefit = 1)
212       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
213         controlSplitReductionFn(std::move(controlSplitReductionFn)),
214         filter(std::move(f)) {}
215 
216   LogicalResult matchAndRewrite(LinalgOp op,
217                                 PatternRewriter &rewriter) const override {
218     return splitReduction(rewriter, op, controlSplitReductionFn, filter);
219   }
220 
221 private:
222   ControlSplitReductionFn controlSplitReductionFn;
223   LinalgTransformationFilter filter;
224 };
225 
226 } // namespace
227 
228 void linalg::populateSplitReductionPattern(
229     RewritePatternSet &patterns,
230     const ControlSplitReductionFn &controlSplitReductionFn,
231     const LinalgTransformationFilter &f) {
232   patterns.add<LinalgSplitReduction>(patterns.getContext(),
233                                      controlSplitReductionFn, f);
234 }
235