1ff6e5508SAlex Zinenko //===- Split.cpp - Structured op splitting --------------------------------===//
2ff6e5508SAlex Zinenko //
3ff6e5508SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ff6e5508SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5ff6e5508SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ff6e5508SAlex Zinenko //
7ff6e5508SAlex Zinenko //===----------------------------------------------------------------------===//
8ff6e5508SAlex Zinenko
9ff6e5508SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h"
10ff6e5508SAlex Zinenko #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11ff6e5508SAlex Zinenko #include "mlir/Dialect/Linalg/Utils/Utils.h"
12*a5c802a4SAlex Zinenko #include "mlir/Dialect/Utils/StaticValueUtils.h"
13ff6e5508SAlex Zinenko
14ff6e5508SAlex Zinenko #include "llvm/ADT/STLExtras.h"
15ff6e5508SAlex Zinenko
16ff6e5508SAlex Zinenko using namespace mlir;
17ff6e5508SAlex Zinenko using namespace mlir::linalg;
18ff6e5508SAlex Zinenko
19ff6e5508SAlex Zinenko /// Extract the slices of `operands` supplied to the given operation `op` such
20ff6e5508SAlex Zinenko /// that they are sufficient to execute the op for the subset of its iteration
21ff6e5508SAlex Zinenko /// space defined by `splitIterationSpace`. The subset is a part of the original
22ff6e5508SAlex Zinenko /// iteration space split at the given `dimension`. If `offset` is provided, it
23ff6e5508SAlex Zinenko /// indicates the iterator value at which the dimension has been split and
24ff6e5508SAlex Zinenko /// requires the "high" part starting at the given offset of the operands to be
25ff6e5508SAlex Zinenko /// generated; otherwise, the "low" part with no offset is generated. Note that
26ff6e5508SAlex Zinenko /// `operands` are not necessarily the actual operands of `op`.
27ff6e5508SAlex Zinenko static SmallVector<Value>
getOperandSlices(RewriterBase & b,Location loc,LinalgOp op,ValueRange splitIterationSpace,ValueRange operands,unsigned dimension,Value offset=nullptr)28*a5c802a4SAlex Zinenko getOperandSlices(RewriterBase &b, Location loc, LinalgOp op,
29ff6e5508SAlex Zinenko ValueRange splitIterationSpace, ValueRange operands,
30ff6e5508SAlex Zinenko unsigned dimension, Value offset = nullptr) {
31ff6e5508SAlex Zinenko SmallVector<Value> slices;
32ff6e5508SAlex Zinenko slices.reserve(op.getNumInputsAndOutputs());
33ff6e5508SAlex Zinenko for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
34ff6e5508SAlex Zinenko auto type = opOperand->get().getType().dyn_cast<ShapedType>();
35ff6e5508SAlex Zinenko AffineMap indexing = op.getTiedIndexingMap(opOperand);
36ff6e5508SAlex Zinenko
37ff6e5508SAlex Zinenko // If the type is not sliceable, or the slice is requested along the
38ff6e5508SAlex Zinenko // dimension that is not used in indexing this type, just use the entire
39ff6e5508SAlex Zinenko // operand.
40ff6e5508SAlex Zinenko if (!type || dimension >= indexing.getNumDims() ||
41ff6e5508SAlex Zinenko !indexing.isFunctionOfDim(dimension)) {
42ff6e5508SAlex Zinenko slices.push_back(opOperand->get());
43ff6e5508SAlex Zinenko continue;
44ff6e5508SAlex Zinenko }
45ff6e5508SAlex Zinenko
46*a5c802a4SAlex Zinenko SmallVector<OpFoldResult> sizes;
47*a5c802a4SAlex Zinenko sizes.reserve(indexing.getNumResults());
48*a5c802a4SAlex Zinenko for (AffineExpr dimIndexing : indexing.getResults()) {
49*a5c802a4SAlex Zinenko sizes.push_back(makeComposedFoldedAffineApply(
50*a5c802a4SAlex Zinenko b, loc, dimIndexing,
51*a5c802a4SAlex Zinenko getAsOpFoldResult(llvm::to_vector(splitIterationSpace))));
52*a5c802a4SAlex Zinenko }
53*a5c802a4SAlex Zinenko SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0));
54*a5c802a4SAlex Zinenko SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1));
55ff6e5508SAlex Zinenko
56ff6e5508SAlex Zinenko if (offset) {
57ff6e5508SAlex Zinenko offsets[dimension] = offset;
58*a5c802a4SAlex Zinenko offsets = applyMapToValues(b, loc, indexing, offsets);
59ff6e5508SAlex Zinenko }
60ff6e5508SAlex Zinenko
61*a5c802a4SAlex Zinenko slices.push_back(createSlice(b, loc,
62ff6e5508SAlex Zinenko operands[opOperand->getOperandNumber()],
63*a5c802a4SAlex Zinenko offsets, sizes, strides));
64ff6e5508SAlex Zinenko }
65ff6e5508SAlex Zinenko
66ff6e5508SAlex Zinenko return slices;
67ff6e5508SAlex Zinenko }
68ff6e5508SAlex Zinenko
69ff6e5508SAlex Zinenko /// Creates a part of the given `op` split along the iteration space `dimension`
70ff6e5508SAlex Zinenko /// with the given `size` and an optional `offset` (default 0). Makes slices
71ff6e5508SAlex Zinenko /// of operands, using the input operands of the original op and the output
72ff6e5508SAlex Zinenko /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
73ff6e5508SAlex Zinenko /// a list of values representing the shape of the iteration space of the
74ff6e5508SAlex Zinenko /// original op and updates it to be the iteration space of the curent part.
75ff6e5508SAlex Zinenko /// Returns the split-out op as well as the output operand values updated with
76ff6e5508SAlex Zinenko /// the partial results produced by this op through `results`.
77*a5c802a4SAlex Zinenko static LinalgOp
createSplitPart(RewriterBase & b,Location loc,LinalgOp op,ValueRange resultOperands,llvm::MutableArrayRef<Value> splitIterationSpace,unsigned dimension,OpFoldResult size,SmallVectorImpl<Value> & results,Value offset=nullptr)78*a5c802a4SAlex Zinenko createSplitPart(RewriterBase &b, Location loc, LinalgOp op,
79*a5c802a4SAlex Zinenko ValueRange resultOperands,
80*a5c802a4SAlex Zinenko llvm::MutableArrayRef<Value> splitIterationSpace,
81*a5c802a4SAlex Zinenko unsigned dimension, OpFoldResult size,
82*a5c802a4SAlex Zinenko SmallVectorImpl<Value> &results, Value offset = nullptr) {
83*a5c802a4SAlex Zinenko ImplicitLocOpBuilder implicit(op.getLoc(), b);
84*a5c802a4SAlex Zinenko splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size);
85ff6e5508SAlex Zinenko SmallVector<Value> operands = llvm::to_vector(
86ff6e5508SAlex Zinenko llvm::map_range(op.getInputOperands(),
87ff6e5508SAlex Zinenko [](OpOperand *opOperand) { return opOperand->get(); }));
88ff6e5508SAlex Zinenko llvm::append_range(operands, resultOperands);
89*a5c802a4SAlex Zinenko operands = getOperandSlices(b, loc, op, splitIterationSpace, operands,
90ff6e5508SAlex Zinenko dimension, offset);
91*a5c802a4SAlex Zinenko Operation *part =
92*a5c802a4SAlex Zinenko op.clone(b, loc, getTensorOutputTypes(op, operands), operands);
93*a5c802a4SAlex Zinenko results = insertSlicesBack(b, loc, op, operands, part->getResults());
94ff6e5508SAlex Zinenko return cast<LinalgOp>(part);
95ff6e5508SAlex Zinenko }
96ff6e5508SAlex Zinenko
splitOp(RewriterBase & rewriter,LinalgOp op,unsigned dimension,OpFoldResult splitPoint)97ff6e5508SAlex Zinenko std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
98ff6e5508SAlex Zinenko LinalgOp op, unsigned dimension,
99ff6e5508SAlex Zinenko OpFoldResult splitPoint) {
100ff6e5508SAlex Zinenko // Bail out on dimension overflow.
101ff6e5508SAlex Zinenko if (dimension >= op.getNumLoops())
102ff6e5508SAlex Zinenko return std::make_pair(op, LinalgOp());
103ff6e5508SAlex Zinenko
104ff6e5508SAlex Zinenko // Compute the iteration space size as values.
105ff6e5508SAlex Zinenko SmallVector<Value, 4> allShapes =
106*a5c802a4SAlex Zinenko op.createFlatListOfOperandDims(rewriter, op.getLoc());
107ff6e5508SAlex Zinenko AffineMap shapesToLoops = op.getShapesToLoopsMap();
108ff6e5508SAlex Zinenko SmallVector<Value, 4> iterationSpaceShapes =
109*a5c802a4SAlex Zinenko applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes);
110ff6e5508SAlex Zinenko
111ff6e5508SAlex Zinenko // Update the iteration space to have `splitPoint` as the size of `dimension`
112ff6e5508SAlex Zinenko // and use it to slice operands and results for a new, smaller instance of the
113ff6e5508SAlex Zinenko // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
114ff6e5508SAlex Zinenko // results back.
115*a5c802a4SAlex Zinenko OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]);
116*a5c802a4SAlex Zinenko OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
117*a5c802a4SAlex Zinenko rewriter, op->getLoc(),
118*a5c802a4SAlex Zinenko AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()),
119*a5c802a4SAlex Zinenko {splitPoint, dimSize});
120ff6e5508SAlex Zinenko SmallVector<Value> splitIterationSpace =
121ff6e5508SAlex Zinenko llvm::to_vector(iterationSpaceShapes);
122ff6e5508SAlex Zinenko SmallVector<Value> originalResults = llvm::to_vector(
123ff6e5508SAlex Zinenko llvm::map_range(op.getOutputOperands(),
124ff6e5508SAlex Zinenko [](OpOperand *opOperand) { return opOperand->get(); }));
125ff6e5508SAlex Zinenko SmallVector<Value> firstResults;
126*a5c802a4SAlex Zinenko LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults,
127*a5c802a4SAlex Zinenko splitIterationSpace, dimension,
128*a5c802a4SAlex Zinenko minSplitPoint, firstResults);
129ff6e5508SAlex Zinenko
130ff6e5508SAlex Zinenko // Update the iteration space to cover the remaining part of the original
131ff6e5508SAlex Zinenko // space, then create another instance of the `op` in that space. The size of
132ff6e5508SAlex Zinenko // the remaining part may become zero, but is never negative because of the
133ff6e5508SAlex Zinenko // adjustment above.
134*a5c802a4SAlex Zinenko AffineExpr d0 = rewriter.getAffineDimExpr(0);
135*a5c802a4SAlex Zinenko AffineExpr d1 = rewriter.getAffineDimExpr(1);
136*a5c802a4SAlex Zinenko OpFoldResult remainingSize = makeComposedFoldedAffineApply(
137*a5c802a4SAlex Zinenko rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint});
138ff6e5508SAlex Zinenko SmallVector<Value> secondResults;
139*a5c802a4SAlex Zinenko ImplicitLocOpBuilder implicit(op.getLoc(), rewriter);
140*a5c802a4SAlex Zinenko Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint);
141*a5c802a4SAlex Zinenko LinalgOp second = createSplitPart(
142*a5c802a4SAlex Zinenko rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension,
143*a5c802a4SAlex Zinenko remainingSize, secondResults, splitPointValue);
144ff6e5508SAlex Zinenko
145ff6e5508SAlex Zinenko // Fixup the linalg.index results in the second part.
146ff6e5508SAlex Zinenko SmallVector<Value> ivAdditions;
147ff6e5508SAlex Zinenko ivAdditions.resize(splitIterationSpace.size());
148ff6e5508SAlex Zinenko ivAdditions[dimension] = splitPointValue;
14981b62f7fSAlex Zinenko linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
150ff6e5508SAlex Zinenko
151ff6e5508SAlex Zinenko // Replace the original op with the results of the two newly created ops.
152ff6e5508SAlex Zinenko rewriter.replaceOp(op, secondResults);
153ff6e5508SAlex Zinenko return std::make_pair(first, second);
154ff6e5508SAlex Zinenko }
155