//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that splits a reduction
// dimension into a parallel and a reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the neutral element attribute of the reduction performed by `op`
/// (e.g. 0.0 for addf, 1.0 for mulf), or a null attribute if `op` is not a
/// recognized combiner.
static Attribute getNeutralElement(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    // The neutral element of max is the most negative value; the neutral
    // element of min is the most positive value.
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return Attribute();
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return Attribute();
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter, bool useAlloc) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  FailureOr<SplitReductionResult> res =
      splitReduction(b, op, controlSplitReductionFn, useAlloc);
  if (failed(res))
    return failure();

  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);

  return res->splitLinalgOp;
}

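/// Core version: split the single reduction dimension of `op` into an outer
/// parallel dimension and an inner reduction dimension, then emit a second
/// linalg.generic that folds the extra parallel dimension back into the
/// original output.
///
/// As an illustrative sketch (assuming `controlSplitReductionFn` returns
/// ratio = 4 and insertSplitDimension = 0; the SSA names are made up), a 1-D
/// f32 add reduction:
///
///   %r = linalg.generic {
///       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
///       iterator_types = ["reduction"]}
///     ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) {
///   ^bb0(%a: f32, %acc: f32):
///     %y = arith.addf %a, %acc : f32
///     linalg.yield %y : f32
///   } -> tensor<f32>
///
/// is rewritten into:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = tensor.expand_shape %in [[0, 1]]
///       : tensor<32xf32> into tensor<4x8xf32>
///   %1 = linalg.init_tensor [4] : tensor<4xf32>
///   %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>)
///       -> tensor<4xf32>
///   %3 = linalg.generic {
///       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
///                        affine_map<(d0, d1) -> (d0)>],
///       iterator_types = ["parallel", "reduction"]}
///     ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { /* addf */ }
///       -> tensor<4xf32>
///   %r = linalg.generic {
///       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
///       iterator_types = ["reduction"]}
///     ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) { /* addf */ }
///       -> tensor<f32>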
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertSplitDimension = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1 && "expected a single reduction dimension");
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension must be static, divisible by the split "
            "ratio, and the insert dimension must be in range");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");

  Operation *reductionOp = combinerOps[0];
  Attribute identity = getNeutralElement(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(
          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged, the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertSplitDimension) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
      continue;
    }
    unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
  }
  Value initOrAllocTensor;
  if (useAlloc) {
    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
        loc,
        RankedTensorType::get(newOutputShape,
                              op.getRegionOutputArgs()[0].getType()),
        ValueRange{});
  } else {
    initOrAllocTensor = b.create<linalg::InitTensorOp>(
        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  }
  Value constantOp = b.create<arith::ConstantOp>(loc, identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertSplitDimension == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction op that only reduces the newly added
  // dimension from the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertSplitDimension == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());

  return SplitReductionResult{
      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
      cast<LinalgOp>(genericOp.getOperation()), reduction};
}

/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...).
/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
/// done as a transform to enable better vectorization.
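///
/// For illustration (an assumed example), given an operand map
/// (d0, d1, d2) -> (d0, d2) with reductionDimPos = 2 and reductionRatio = 4,
/// the result is (d0, d1, d2, d3) -> (d0, d2 * 4 + d3): d2 becomes the
/// parallel chunk index and d3 the remaining reduction dimension.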
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos,
                                   int64_t reductionRatio) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  AffineMap composeMap = shiftedIdMap.replace(
      reductionDim, reductionDim * reductionRatio + reductionDimP1,
      shiftedIdMap.getNumDims(), /*numSymbols=*/0);
  return map.compose(composeMap);
}

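/// Insert a new parallel dimension `d_reductionDimPos` into the indexing map
/// of `opOperand`, shifting existing dims at positions >= reductionDimPos + 1
/// up by one to make room for it.
///
/// For illustration (an assumed example), given an output map
/// (d0, d1, d2) -> (d0, d1) with reductionDimPos = 2, the result is
/// (d0, d1, d2, d3) -> (d0, d1, d2).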
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
                                   unsigned reductionDimPos, int64_t size) {
  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
  AffineMap map = op.getTiedIndexingMap(&opOperand);
  AffineMap idMap =
      AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
}

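/// Instead of introducing an ExpandShapeOp as `splitReduction` does, rewrite
/// the reduction dimension `k` as `k * splitFactor + kk` and let a shape-only
/// i1 tensor carry the extents of the two new loops.
///
/// As an illustrative sketch (assuming `controlSplitReductionFn` returns
/// splitFactor = 4 and insertSplitDimension = 2; the SSA names are made up):
///
///   %0 = linalg.matmul ins(%A, %B : tensor<16x256xf32>, tensor<256x32xf32>)
///     outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
///
/// is rewritten into:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>)
///       -> tensor<16x32x64xf32>
///   %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
///   %3 = linalg.generic {
///       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>,
///                        affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>,
///                        affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
///                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
///       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
///     ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>,
///                      tensor<64x4xi1>)
///     outs(%1 : tensor<16x32x64xf32>) {
///   ^bb0(%a: f32, %b: f32, %shape_only: i1, %acc: f32):
///     %m = arith.mulf %a, %b : f32
///     %s = arith.addf %m, %acc : f32
///     linalg.yield %s : f32
///   } -> tensor<16x32x64xf32>
///   %4 = linalg.generic {
///       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///                        affine_map<(d0, d1, d2) -> (d0, d1)>],
///       iterator_types = ["parallel", "parallel", "reduction"]}
///     ins(%3 : tensor<16x32x64xf32>) outs(%C : tensor<16x32xf32>) {
///   ^bb0(%x: f32, %acc: f32):
///     %s = arith.addf %x, %acc : f32
///     linalg.yield %s : f32
///   } -> tensor<16x32xf32>
///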
/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(op);

  // Matcher part: enforce preconditions.
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t splitFactor = control.first;
  unsigned insertSplitDimension = control.second;
  if (splitFactor <= 1)
    return b.notifyMatchFailure(op, "split factor needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  if (dims.empty())
    return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");

  unsigned reductionDimPos = dims[0];
  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDimPos];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % splitFactor != 0 ||
      insertSplitDimension >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "first reduction dimension must be static, divisible by the "
            "split factor, and the insert dimension must be in range");

  SmallVector<Operation *> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
    return b.notifyMatchFailure(op, "cannot match a reduction pattern");

  SmallVector<Attribute> neutralElements = llvm::to_vector<4>(
      llvm::map_range(combinerOps, [&](Operation *reductionOp) {
        return getNeutralElement(reductionOp);
      }));
  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
    return b.notifyMatchFailure(op, "unknown reduction neutral");

  // TODO: relax this when multi-reduction support is available.
  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
    return b.notifyMatchFailure(op, "expect one reduction per output");

  // Rewrite part.
  // Step 1. Build the intermediate outputs filled with the proper
  // neutralElements. Such outputs have the same shape as the original
  // outputs, with an extra dimension inserted at `insertSplitDimension`.
  //
  // Consider a minimal example where `k` is reduced:
  //     O(i, j) += I(i, j, k)
  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
  // The reduction is split as k = kk * 16 + k', where the new dimension kk
  // (extent 128/16 == 8) is parallel and k' (extent 16) remains reduced.
  // The compute is rewritten as:
  //   a. O_i(kk, i, j) += I(i, j, kk * 16 + k')
  //   b. O(i, j) += O_i(kk, i, j)
  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
  Location loc = op->getLoc();
  MLIRContext *context = op.getContext();
  // For now assume outputs are 1-1 with reduction neutralElements.
  // TODO: generalize when multi-reduction support is available.
  SmallVector<Value> newOutputs;
  newOutputs.reserve(op.getNumOutputs());
  SmallVector<Operation *> initOrAllocTensorOps;
  SmallVector<linalg::FillOp> fillOps;
  fillOps.reserve(op.getNumOutputs());
  for (auto it : llvm::zip(op.outputs(), neutralElements)) {
    Value rankedTensor = std::get<0>(it);
    auto t = rankedTensor.getType().cast<RankedTensorType>();
    RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
        reductionDimSize / splitFactor, insertSplitDimension);
    SmallVector<Value> dims =
        tensor::createDynamicDimValues(b, loc, rankedTensor);
    Value initOrAllocTensor;
    if (useAlloc) {
      initOrAllocTensor =
          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
    } else {
      initOrAllocTensor = b.create<linalg::InitTensorOp>(
          loc, dims, newT.getShape(), t.getElementType());
    }
    Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
    fillOps.push_back(
        b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
    newOutputs.push_back(fillOps.back().getResult(0));
    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
  }

  // Step 2. Reindex / expand indexing maps.
  // Reindex existing input indexings: k -> k * splitFactor + k'.
  SmallVector<AffineMap> newMaps;
  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
  for (OpOperand *o : op.getInputOperands())
    newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
  // Provision a new indexing for the shape-only tensor.
  auto nDims = op.getNumLoops() + 1;
  auto redDim = getAffineDimExpr(reductionDimPos, context);
  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
  // Expand existing output indexings.
  // TODO: a subset of these may not reduce along reducePos and should be
  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
  // available.
  for (OpOperand *o : op.getOutputOperands())
    newMaps.push_back(insertParallelDim(op, *o, reductionDimPos,
                                        reductionDimSize / splitFactor));

  // Step 3. Handle operands.
  // Compute the new input tensors.
  auto newInputs = llvm::to_vector<4>(op.inputs());
  // Add a single shape-only tensor to carry the dimensions without resorting to
  // more complex inversions.
  newInputs.push_back(b.create<linalg::InitTensorOp>(
      loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
      b.getIntegerType(1)));
  // Output tensors are already good to go.

  // Step 4. Create the new op matching the original op with an extra parallel
  // dimension.
  SmallVector<StringRef> iteratorTypes =
      llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange<StringAttr>());
  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
                       getParallelIteratorTypeName());
  GenericOp genericOp =
      b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
                          newOutputs, newMaps, iteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());
  genericOp.region().front().insertArgument(reductionDimPos,
                                            b.getIntegerType(1), loc);

  // Step 5. Create new reduction ops that only reduce the newly added
  // dimensions from the previous op.
  // For now assume outputs are 1-1 with reduction ops.
  // TODO: a subset of these may not reduce in the first place and do not
  // require a new op, when multi-reduction support is available.
  // TODO: all results can be handled in a single GenericOp, when
  // multi-reduction support is available.
  SmallVector<LinalgOp> results;
  for (auto it :
       llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) {
    Value reindexedOutput = std::get<0>(it);
    Value originalOutput = std::get<1>(it);
    auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
    Operation *combinerOp = std::get<2>(it);

    AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
    SmallVector<AffineMap> indexingMaps = {
        map, map.dropResult(insertSplitDimension)};
    SmallVector<StringRef> reductionIteratorTypes(
        originalOutputType.getRank() + 1, getParallelIteratorTypeName());
    reductionIteratorTypes[insertSplitDimension] =
        getReductionIteratorTypeName();

    // clang-format off
    auto reductionOp = b.create<GenericOp>(
        loc,
        originalOutputType,
        reindexedOutput,
        originalOutput,
        indexingMaps,
        reductionIteratorTypes,
        [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
          Operation *clonedReductionOp = b.clone(*combinerOp);
          clonedReductionOp->setOperand(0, bbArgs[0]);
          clonedReductionOp->setOperand(1, bbArgs[1]);
          b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
        });
    // clang-format on

    results.push_back(reductionOp);
  }

  // TODO: extend when multi-reduction support is available.
  assert(fillOps.size() == results.size() && results.size() == 1);
  b.replaceOp(op, results.front()->getResults());
  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                              cast<LinalgOp>(genericOp.getOperation()),
                              results.front()};
}

namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps that satisfy
  /// `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, bool useAlloc = false,
                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        useAlloc(useAlloc), filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
                          useAlloc);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  bool useAlloc;
  LinalgTransformationFilter filter;
};

} // namespace

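/// A minimal usage sketch (assumed driver code, not part of this file;
/// `funcOp` and the constant-returning control function are made-up):
///
///   RewritePatternSet patterns(ctx);
///   populateSplitReductionPattern(
///       patterns,
///       [](LinalgOp op) { return std::pair<int64_t, unsigned>(4, 0); },
///       LinalgTransformationFilter());
///   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));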
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f, bool useAlloc) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f, useAlloc);
}