//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg transformation to break a reduction dimension
10 // between a parallel and a reduction dimension.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/Analysis/SliceAnalysis.h"
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 
27 /// Return the identity numeric value associated to the give op.
28 static Optional<Attribute> getIdentity(Operation *op) {
29   // Builder only used as helper for attribute creation.
30   OpBuilder b(op->getContext());
31   Type resultType = op->getResult(0).getType();
32   if (auto floatType = resultType.dyn_cast<FloatType>()) {
33     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
34     if (isa<arith::AddFOp>(op))
35       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
36     if (isa<arith::MulFOp>(op))
37       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
38     if (isa<arith::MaxFOp>(op))
39       return b.getFloatAttr(resultType,
40                             llvm::APFloat::getLargest(semantic, true));
41     if (isa<arith::MinFOp>(op))
42       return b.getFloatAttr(resultType,
43                             llvm::APFloat::getLargest(semantic, true));
44     return llvm::None;
45   }
46   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
47     return b.getIntegerAttr(resultType, 0);
48   if (isa<arith::AndIOp>(op))
49     return b.getIntegerAttr(resultType, -1);
50   if (isa<arith::MaxSIOp>(op))
51     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
52   if (isa<arith::MinSIOp>(op))
53     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
54   if (isa<arith::MulIOp>(op))
55     return b.getIntegerAttr(resultType, 1);
56   return llvm::None;
57 }
58 
59 FailureOr<LinalgOp> mlir::linalg::splitReduction(
60     PatternRewriter &b, LinalgOp op,
61     const ControlSplitReductionFn &controlSplitReductionFn,
62     const LinalgTransformationFilter &filter) {
63   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
64       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
65       !op.hasOnlyProjectedPermutations())
66     return b.notifyMatchFailure(op, "precondition not met");
67   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
68   int64_t ratio = control.first;
69   unsigned insertDimIndex = control.second;
70   if (ratio <= 1)
71     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
72   SmallVector<unsigned> dims;
73   op.getReductionDims(dims);
74   assert(dims.size() == 1);
75   unsigned reductionDim = dims[0];
76   Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
77   if (!loopRanges)
78     return b.notifyMatchFailure(op, "Cannot analyze loops");
79   int64_t reductionDimSize = (*loopRanges)[reductionDim];
80   if (reductionDimSize == ShapedType::kDynamicSize ||
81       reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size())
82     return b.notifyMatchFailure(
83         op, "Reduction dimension not divisible by split ratio");
84   SmallVector<Operation *, 4> combinerOps;
85   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
86       combinerOps.size() != 1)
87     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
88   Operation *reductionOp = combinerOps[0];
89   Optional<Attribute> identity = getIdentity(reductionOp);
90   if (!identity)
91     return b.notifyMatchFailure(op, "Unknown identity value for the redution");
92 
93   Location loc = op->getLoc();
94   SmallVector<Value> newInputs;
95   SmallVector<AffineMap> newMaps;
96   // Calculate the new shapes and indexing maps of the input operands.
97   for (OpOperand *operand : op.getInputOperands()) {
98     AffineMap map = op.getTiedIndexingMap(operand);
99     SmallVector<int64_t> newShape;
100     SmallVector<AffineExpr> exprs;
101     SmallVector<ReassociationIndices> reassociation;
102     unsigned index = 0;
103     for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
104       unsigned dim = map.getDimPosition(idx);
105       if (reductionDim == dim) {
106         newShape.push_back(ratio);
107         newShape.push_back(op.getShape(operand)[idx] / ratio);
108         reassociation.push_back({index++, index++});
109         exprs.push_back(b.getAffineDimExpr(insertDimIndex));
110         exprs.push_back(
111             b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
112         continue;
113       }
114       newShape.push_back(op.getShape(operand)[idx]);
115       exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
116       reassociation.push_back({index++});
117     }
118     newMaps.push_back(
119         AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
120     // If the shape is unchanged the input doesn't change.
121     if (newShape == op.getShape(operand)) {
122       newInputs.push_back(operand->get());
123       continue;
124     }
125     Type newType = RankedTensorType::get(
126         newShape,
127         operand->get().getType().cast<RankedTensorType>().getElementType());
128     Value newInput = b.create<tensor::ExpandShapeOp>(
129         loc, newType, operand->get(), reassociation);
130     newInputs.push_back(newInput);
131   }
132   // Calculate the new output map and shape, we insert the new dimension based
133   // on the index returned by `controlSplitReductionFn`.
134   SmallVector<int64_t> newOutputShape;
135   AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
136   ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
137   SmallVector<AffineExpr> outputExpr;
138   for (unsigned idx :
139        llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
140     if (idx == insertDimIndex) {
141       newOutputShape.push_back(ratio);
142       outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
143       continue;
144     }
145     unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
146     newOutputShape.push_back(oldShape[oldDim]);
147     unsigned dim = oldOutputMap.getDimPosition(oldDim);
148     outputExpr.push_back(
149         b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
150   }
151   Value initTensor = b.create<linalg::InitTensorOp>(
152       loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
153   Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
154   Value identityTensor =
155       b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
156           .getResult(0);
157 
158   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
159                                    op.getContext()));
160   SmallVector<StringRef> newIteratorTypes;
161   for (auto &it : llvm::enumerate(op.iterator_types())) {
162     if (insertDimIndex == it.index())
163       newIteratorTypes.push_back(getParallelIteratorTypeName());
164     newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
165   }
166   // Create the new op matching the original op with an extra parallel
167   // dimension.
168   GenericOp genericOp = b.create<GenericOp>(
169       loc, TypeRange({initTensor.getType()}), newInputs,
170       ValueRange({identityTensor}), newMaps, newIteratorTypes);
171   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
172                        genericOp.region().begin());
173 
174   // Then create a new reduction that only reduce the newly added dimension from
175   // the previous op.
176   unsigned intermRank = newOutputShape.size();
177   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
178   SmallVector<Value> outputOperands = op.getOutputOperands();
179   SmallVector<StringRef> reductionIteratorTypes;
180   SmallVector<AffineExpr> exprs;
181   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
182     if (insertDimIndex == i) {
183       reductionIteratorTypes.push_back(getReductionIteratorTypeName());
184     } else {
185       exprs.push_back(b.getAffineDimExpr(i));
186       reductionIteratorTypes.push_back(getParallelIteratorTypeName());
187     }
188   }
189   AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
190   SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
191 
192   auto reduction = b.create<GenericOp>(
193       loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
194       outputOperands, reductionMaps, reductionIteratorTypes,
195       [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
196         Operation *clonedReductionOp = b.clone(*reductionOp);
197         clonedReductionOp->setOperand(0, inputs[0]);
198         clonedReductionOp->setOperand(1, inputs[1]);
199         b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
200       });
201   b.replaceOp(op, reduction.getResults());
202   filter.replaceLinalgTransformationFilter(b, genericOp);
203   filter.replaceLinalgTransformationFilter(b, reduction);
204   return cast<LinalgOp>(genericOp.getOperation());
205 }
206 
207 namespace {
208 
209 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
210   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
211   LinalgSplitReduction(MLIRContext *context,
212                        ControlSplitReductionFn controlSplitReductionFn,
213                        LinalgTransformationFilter f, PatternBenefit benefit = 1)
214       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
215         controlSplitReductionFn(std::move(controlSplitReductionFn)),
216         filter(std::move(f)) {}
217 
218   LogicalResult matchAndRewrite(LinalgOp op,
219                                 PatternRewriter &rewriter) const override {
220     return splitReduction(rewriter, op, controlSplitReductionFn, filter);
221   }
222 
223 private:
224   ControlSplitReductionFn controlSplitReductionFn;
225   LinalgTransformationFilter filter;
226 };
227 
228 } // namespace
229 
230 void linalg::populateSplitReductionPattern(
231     RewritePatternSet &patterns,
232     const ControlSplitReductionFn &controlSplitReductionFn,
233     const LinalgTransformationFilter &f) {
234   patterns.add<LinalgSplitReduction>(patterns.getContext(),
235                                      controlSplitReductionFn, f);
236 }
237