1 //===- Split.cpp - Structured op splitting --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13
14 #include "llvm/ADT/STLExtras.h"
15
16 using namespace mlir;
17 using namespace mlir::linalg;
18
19 /// Extract the slices of `operands` supplied to the given operation `op` such
20 /// that they are sufficient to execute the op for the subset of its iteration
21 /// space defined by `splitIterationSpace`. The subset is a part of the original
22 /// iteration space split at the given `dimension`. If `offset` is provided, it
23 /// indicates the iterator value at which the dimension has been split and
24 /// requires the "high" part starting at the given offset of the operands to be
25 /// generated; otherwise, the "low" part with no offset is generated. Note that
26 /// `operands` are not necessarily the actual operands of `op`.
27 static SmallVector<Value>
getOperandSlices(RewriterBase & b,Location loc,LinalgOp op,ValueRange splitIterationSpace,ValueRange operands,unsigned dimension,Value offset=nullptr)28 getOperandSlices(RewriterBase &b, Location loc, LinalgOp op,
29 ValueRange splitIterationSpace, ValueRange operands,
30 unsigned dimension, Value offset = nullptr) {
31 SmallVector<Value> slices;
32 slices.reserve(op.getNumInputsAndOutputs());
33 for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
34 auto type = opOperand->get().getType().dyn_cast<ShapedType>();
35 AffineMap indexing = op.getTiedIndexingMap(opOperand);
36
37 // If the type is not sliceable, or the slice is requested along the
38 // dimension that is not used in indexing this type, just use the entire
39 // operand.
40 if (!type || dimension >= indexing.getNumDims() ||
41 !indexing.isFunctionOfDim(dimension)) {
42 slices.push_back(opOperand->get());
43 continue;
44 }
45
46 SmallVector<OpFoldResult> sizes;
47 sizes.reserve(indexing.getNumResults());
48 for (AffineExpr dimIndexing : indexing.getResults()) {
49 sizes.push_back(makeComposedFoldedAffineApply(
50 b, loc, dimIndexing,
51 getAsOpFoldResult(llvm::to_vector(splitIterationSpace))));
52 }
53 SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0));
54 SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1));
55
56 if (offset) {
57 offsets[dimension] = offset;
58 offsets = applyMapToValues(b, loc, indexing, offsets);
59 }
60
61 slices.push_back(createSlice(b, loc,
62 operands[opOperand->getOperandNumber()],
63 offsets, sizes, strides));
64 }
65
66 return slices;
67 }
68
69 /// Creates a part of the given `op` split along the iteration space `dimension`
70 /// with the given `size` and an optional `offset` (default 0). Makes slices
71 /// of operands, using the input operands of the original op and the output
72 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
73 /// a list of values representing the shape of the iteration space of the
74 /// original op and updates it to be the iteration space of the curent part.
75 /// Returns the split-out op as well as the output operand values updated with
76 /// the partial results produced by this op through `results`.
77 static LinalgOp
createSplitPart(RewriterBase & b,Location loc,LinalgOp op,ValueRange resultOperands,llvm::MutableArrayRef<Value> splitIterationSpace,unsigned dimension,OpFoldResult size,SmallVectorImpl<Value> & results,Value offset=nullptr)78 createSplitPart(RewriterBase &b, Location loc, LinalgOp op,
79 ValueRange resultOperands,
80 llvm::MutableArrayRef<Value> splitIterationSpace,
81 unsigned dimension, OpFoldResult size,
82 SmallVectorImpl<Value> &results, Value offset = nullptr) {
83 ImplicitLocOpBuilder implicit(op.getLoc(), b);
84 splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size);
85 SmallVector<Value> operands = llvm::to_vector(
86 llvm::map_range(op.getInputOperands(),
87 [](OpOperand *opOperand) { return opOperand->get(); }));
88 llvm::append_range(operands, resultOperands);
89 operands = getOperandSlices(b, loc, op, splitIterationSpace, operands,
90 dimension, offset);
91 Operation *part =
92 op.clone(b, loc, getTensorOutputTypes(op, operands), operands);
93 results = insertSlicesBack(b, loc, op, operands, part->getResults());
94 return cast<LinalgOp>(part);
95 }
96
splitOp(RewriterBase & rewriter,LinalgOp op,unsigned dimension,OpFoldResult splitPoint)97 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
98 LinalgOp op, unsigned dimension,
99 OpFoldResult splitPoint) {
100 // Bail out on dimension overflow.
101 if (dimension >= op.getNumLoops())
102 return std::make_pair(op, LinalgOp());
103
104 // Compute the iteration space size as values.
105 SmallVector<Value, 4> allShapes =
106 op.createFlatListOfOperandDims(rewriter, op.getLoc());
107 AffineMap shapesToLoops = op.getShapesToLoopsMap();
108 SmallVector<Value, 4> iterationSpaceShapes =
109 applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes);
110
111 // Update the iteration space to have `splitPoint` as the size of `dimension`
112 // and use it to slice operands and results for a new, smaller instance of the
113 // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
114 // results back.
115 OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]);
116 OpFoldResult minSplitPoint = makeComposedFoldedAffineMin(
117 rewriter, op->getLoc(),
118 AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()),
119 {splitPoint, dimSize});
120 SmallVector<Value> splitIterationSpace =
121 llvm::to_vector(iterationSpaceShapes);
122 SmallVector<Value> originalResults = llvm::to_vector(
123 llvm::map_range(op.getOutputOperands(),
124 [](OpOperand *opOperand) { return opOperand->get(); }));
125 SmallVector<Value> firstResults;
126 LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults,
127 splitIterationSpace, dimension,
128 minSplitPoint, firstResults);
129
130 // Update the iteration space to cover the remaining part of the original
131 // space, then create another instance of the `op` in that space. The size of
132 // the remaining part may become zero, but is never negative because of the
133 // adjustment above.
134 AffineExpr d0 = rewriter.getAffineDimExpr(0);
135 AffineExpr d1 = rewriter.getAffineDimExpr(1);
136 OpFoldResult remainingSize = makeComposedFoldedAffineApply(
137 rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint});
138 SmallVector<Value> secondResults;
139 ImplicitLocOpBuilder implicit(op.getLoc(), rewriter);
140 Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint);
141 LinalgOp second = createSplitPart(
142 rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension,
143 remainingSize, secondResults, splitPointValue);
144
145 // Fixup the linalg.index results in the second part.
146 SmallVector<Value> ivAdditions;
147 ivAdditions.resize(splitIterationSpace.size());
148 ivAdditions[dimension] = splitPointValue;
149 linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
150
151 // Replace the original op with the results of the two newly created ops.
152 rewriter.replaceOp(op, secondResults);
153 return std::make_pair(first, second);
154 }
155