//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements linalg transformation to break a reduction dimension
10 // between a parallel and a reduction dimension.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/Analysis/SliceAnalysis.h"
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/Linalg/IR/Linalg.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 using namespace mlir::linalg;
26 
27 /// Return the identity numeric value associated to the give op.
28 static Optional<Attribute> getIdentity(Operation *op) {
29   // Builder only used as helper for attribute creation.
30   OpBuilder b(op->getContext());
31   Type resultType = op->getResult(0).getType();
32   if (auto floatType = resultType.dyn_cast<FloatType>()) {
33     const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
34     if (isa<arith::AddFOp>(op))
35       return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
36     if (isa<arith::MulFOp>(op))
37       return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
38     if (isa<arith::MaxFOp>(op))
39       return b.getFloatAttr(resultType,
40                             llvm::APFloat::getLargest(semantic, true));
41     if (isa<arith::MinFOp>(op))
42       return b.getFloatAttr(resultType,
43                             llvm::APFloat::getLargest(semantic, true));
44     return llvm::None;
45   }
46   if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
47     return b.getIntegerAttr(resultType, 0);
48   if (isa<arith::AndIOp>(op))
49     return b.getIntegerAttr(resultType, -1);
50   if (isa<arith::MaxSIOp>(op))
51     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
52   if (isa<arith::MinSIOp>(op))
53     return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
54   if (isa<arith::MulIOp>(op))
55     return b.getIntegerAttr(resultType, 1);
56   return llvm::None;
57 }
58 
59 FailureOr<LinalgOp> mlir::linalg::splitReduction(
60     PatternRewriter &b, LinalgOp op,
61     const ControlSplitReductionFn &controlSplitReductionFn,
62     const LinalgTransformationFilter &filter) {
63   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
64       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
65       !op.hasOnlyProjectedPermutations())
66     return b.notifyMatchFailure(op, "precondition not met");
67 
68   FailureOr<SplitReductionResult> res =
69       splitReduction(b, op, controlSplitReductionFn);
70   if (failed(res))
71     return failure();
72 
73   filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
74   filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
75 
76   return res->splitLinalgOp;
77 }
78 
/// Core implementation: rewrite `op`, a linalg op with a single reduction
/// dimension, into (1) a generic op where that dimension is split into an
/// outer parallel dimension of size `ratio` and an inner reduction dimension,
/// followed by (2) a second generic op that reduces the newly added parallel
/// dimension. `controlSplitReductionFn` chooses the split ratio and where the
/// new dimension is inserted in the iteration space.
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // {split ratio, index in the loop ranges at which to insert the new
  // parallel dimension}.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  // The reduction extent must be static and evenly divisible by the ratio,
  // and the insertion index must be in range.
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");

  // The body must contain exactly one combiner op feeding the yielded value;
  // it determines the identity element used to fill the intermediate tensor.
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduction dimension is expanded into {ratio, size/ratio}: an
        // outer parallel dim mapped to `insertDimIndex` and the remaining
        // inner reduction dim.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
        continue;
      }
      // Other dimensions are kept, with dim positions shifted by one when
      // they come after the newly inserted dimension.
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    // Otherwise materialize the split via a tensor.expand_shape.
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      // The new parallel dimension of extent `ratio` is inserted here.
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    // Map the remaining result positions back to the old output dimensions,
    // shifting dim positions that come after the inserted dimension.
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
  }
  // Create the intermediate output tensor, filled with the identity value of
  // the reduction so partial results combine correctly.
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  // Iterator types are the original ones plus a parallel iterator inserted at
  // `insertDimIndex` for the new dimension.
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  // Reuse the original body unchanged; the split only affects maps/shapes.
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  // All dimensions stay parallel except the inserted one, which is reduced
  // and therefore dropped from the output map.
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  // The combining op's body is a clone of the matched combiner applied to the
  // intermediate element and the output accumulator.
  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}
227 
228 namespace {
229 
230 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
231   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
232   LinalgSplitReduction(MLIRContext *context,
233                        ControlSplitReductionFn controlSplitReductionFn,
234                        LinalgTransformationFilter f, PatternBenefit benefit = 1)
235       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
236         controlSplitReductionFn(std::move(controlSplitReductionFn)),
237         filter(std::move(f)) {}
238 
239   LogicalResult matchAndRewrite(LinalgOp op,
240                                 PatternRewriter &rewriter) const override {
241     return splitReduction(rewriter, op, controlSplitReductionFn, filter);
242   }
243 
244 private:
245   ControlSplitReductionFn controlSplitReductionFn;
246   LinalgTransformationFilter filter;
247 };
248 
249 } // namespace
250 
251 void linalg::populateSplitReductionPattern(
252     RewritePatternSet &patterns,
253     const ControlSplitReductionFn &controlSplitReductionFn,
254     const LinalgTransformationFilter &f) {
255   patterns.add<LinalgSplitReduction>(patterns.getContext(),
256                                      controlSplitReductionFn, f);
257 }
258