//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a linalg transformation that splits a reduction
// dimension into an outer parallel dimension and an inner reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the identity numeric value associated with the given op.
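/// For instance (a non-exhaustive sketch of the cases handled below):
///   arith.addf -> 0.0, arith.mulf -> 1.0, arith.addi -> 0,
///   arith.andi -> all-ones, arith.maxsi -> the type's signed minimum.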
static Attribute getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    // -inf is the neutral element of maxf; +inf is the neutral element of
    // minf.
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false));
    return Attribute();
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  // Use width-aware limits so that narrow integer types do not truncate the
  // 64-bit extrema to an incorrect value.
  if (isa<arith::MaxSIOp>(op))
    return IntegerAttr::get(
        resultType,
        llvm::APInt::getSignedMinValue(resultType.getIntOrFloatBitWidth()));
  if (isa<arith::MinSIOp>(op))
    return IntegerAttr::get(
        resultType,
        llvm::APInt::getSignedMaxValue(resultType.getIntOrFloatBitWidth()));
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  FailureOr<SplitReductionResult> res =
      splitReduction(b, op, controlSplitReductionFn);
  if (failed(res))
    return failure();

  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

  return res->splitLinalgOp;
}

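/// As a sketch, assuming the control function returns
/// {ratio = 4, insertSplitDimension = 0} (the exact IR below is illustrative,
/// SSA names and attribute spellings included), a 1-D float addition reduction
///
///   %r = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> ()>],
///          iterator_types = ["reduction"]}
///        ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) {
///        ^bb0(%arg0: f32, %arg1: f32):
///          %y = arith.addf %arg0, %arg1 : f32
///          linalg.yield %y : f32
///        } -> tensor<f32>
///
/// is rewritten into a partial reduction over an expanded tensor, followed by
/// a final reduction of the extra parallel dimension (bodies elided):
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = tensor.expand_shape %in [[0, 1]]
///       : tensor<32xf32> into tensor<4x8xf32>
///   %1 = linalg.init_tensor [4] : tensor<4xf32>
///   %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>)
///       -> tensor<4xf32>
///   %3 = linalg.generic {
///          indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                           affine_map<(d0, d1) -> (d0)>],
///          iterator_types = ["parallel", "reduction"]}
///        ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { ... }
///        -> tensor<4xf32>
///   %r = linalg.generic {
///          indexing_maps = [affine_map<(d0) -> (d0)>,
///                           affine_map<(d0) -> ()>],
///          iterator_types = ["reduction"]}
///        ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) { ... }
///        -> tensor<f32>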
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1 && "expected a single reduction dimension");
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be static and divisible by the split "
            "ratio, and the insertion dimension must be in range");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged, reuse the original input operand directly.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted
  // based on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(loc, constantOp, initTensor).getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction op that only reduces the newly added dimension
  // of the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...).
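/// For instance, with reductionDimPos = 2 and reductionRatio = 4, an operand
/// map (d0, d1, d2) -> (d0, d2) becomes
/// (d0, d1, d2, d3) -> (d0, d2 * 4 + d3), where d3 is the newly created inner
/// reduction dimension.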
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

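/// Add a result expression for the new parallel dimension (at loop position
/// `reductionDimPos`) to the indexing map of `opOperand`, shifting the
/// original loop dimensions around it. For instance, with
/// reductionDimPos = 2, an output map (d0, d1, d2) -> (d0, d1) becomes
/// (d0, d1, d2, d3) -> (d0, d1, d2). (`size`, the extent of the inserted
/// dimension, is currently unused.)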
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part: enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension must be static and divisible by the "
            "split factor, and the insertion dimension must be in range");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral element");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != (int)neutralElements.size())
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs are of the same shape with an extra dimension
  // inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The reduction index is decomposed as k = 16 * kk + k' with kk in [0, 8)
  // and k' in [0, 16), and the compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, 16 * kk + k')
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initTensor = b.create<linalg::InitTensorOp>(
        loc, dims, newT.getShape(), t.getElementType());
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(b.create<linalg::FillOp>(loc, constantOp, initTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting
  // to more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  // The shape-only tensor is the last input: give it a matching (unused)
  // block argument placed right after the original input arguments.
  genericOp.region().front().insertArgument(
      static_cast<unsigned>(op.getNumInputs()), b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1 &&
         "expected a single reduction");
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps that satisfy
  /// `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  LinalgTransformationFilter filter;
};

} // namespace

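// A minimal usage sketch, assuming a caller that owns a RewritePatternSet and
// applies it with the greedy driver; the {4, 0} control values (split by 4,
// insert the new parallel dimension at loop position 0) are illustrative:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateSplitReductionPattern(
//       patterns,
//       [](LinalgOp op) { return std::pair<int64_t, unsigned>(4, 0); },
//       LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));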
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f);
}