1 //===- Split.cpp - Structured op splitting --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 
13 #include "llvm/ADT/STLExtras.h"
14 
15 using namespace mlir;
16 using namespace mlir::linalg;
17 
18 /// Turns an OpFoldResult into a value, creating an index-typed constant if
19 /// necessary.
20 static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
21                                      OpFoldResult opFoldResult) {
22   if (opFoldResult.is<Value>())
23     return opFoldResult.get<Value>();
24   auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
25   return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
26 }
27 
28 /// Extract the slices of `operands` supplied to the given operation `op` such
29 /// that they are sufficient to execute the op for the subset of its iteration
30 /// space defined by `splitIterationSpace`. The subset is a part of the original
31 /// iteration space split at the given `dimension`. If `offset` is provided, it
32 /// indicates the iterator value at which the dimension has been split and
33 /// requires the "high" part starting at the given offset of the operands to be
34 /// generated; otherwise, the "low" part with no offset is generated. Note that
35 /// `operands` are not necessarily the actual operands of `op`.
36 static SmallVector<Value>
37 getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
38                  ValueRange splitIterationSpace, ValueRange operands,
39                  unsigned dimension, Value offset = nullptr) {
40   SmallVector<Value> slices;
41   slices.reserve(op.getNumInputsAndOutputs());
42   for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
43     auto type = opOperand->get().getType().dyn_cast<ShapedType>();
44     AffineMap indexing = op.getTiedIndexingMap(opOperand);
45 
46     // If the type is not sliceable, or the slice is requested along the
47     // dimension that is not used in indexing this type, just use the entire
48     // operand.
49     if (!type || dimension >= indexing.getNumDims() ||
50         !indexing.isFunctionOfDim(dimension)) {
51       slices.push_back(opOperand->get());
52       continue;
53     }
54 
55     SmallVector<Value, 4> sizes =
56         applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
57     SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
58     SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
59 
60     if (offset) {
61       offsets[dimension] = offset;
62       IRRewriter rewriter(builder);
63       offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets);
64     }
65 
66     slices.push_back(createSlice(builder, op.getLoc(),
67                                  operands[opOperand->getOperandNumber()],
68                                  offsets, getAsOpFoldResult(sizes), strides));
69   }
70 
71   return slices;
72 }
73 
74 /// Creates a part of the given `op` split along the iteration space `dimension`
75 /// with the given `size` and an optional `offset` (default 0). Makes slices
76 /// of operands, using the input operands of the original op and the output
77 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
78 /// a list of values representing the shape of the iteration space of the
79 /// original op and updates it to be the iteration space of the curent part.
80 /// Returns the split-out op as well as the output operand values updated with
81 /// the partial results produced by this op through `results`.
82 static LinalgOp createSplitPart(
83     ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
84     llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
85     Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
86   splitIterationSpace[dimension] = size;
87   SmallVector<Value> operands = llvm::to_vector(
88       llvm::map_range(op.getInputOperands(),
89                       [](OpOperand *opOperand) { return opOperand->get(); }));
90   llvm::append_range(operands, resultOperands);
91   operands = getOperandSlices(builder, op, splitIterationSpace, operands,
92                               dimension, offset);
93   Operation *part = op.clone(builder, op.getLoc(),
94                              getTensorOutputTypes(op, operands), operands);
95   results = insertSlicesBack(builder, builder.getLoc(), op, operands,
96                              part->getResults());
97   return cast<LinalgOp>(part);
98 }
99 
100 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
101                                               LinalgOp op, unsigned dimension,
102                                               OpFoldResult splitPoint) {
103   // Bail out on dimension overflow.
104   if (dimension >= op.getNumLoops())
105     return std::make_pair(op, LinalgOp());
106 
107   // Compute the iteration space size as values.
108   ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
109   SmallVector<Value, 4> allShapes =
110       op.createFlatListOfOperandDims(builder, op.getLoc());
111   AffineMap shapesToLoops = op.getShapesToLoopsMap();
112   SmallVector<Value, 4> iterationSpaceShapes =
113       applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
114 
115   // Update the iteration space to have `splitPoint` as the size of `dimension`
116   // and use it to slice operands and results for a new, smaller instance of the
117   // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
118   // results back.
119   Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
120   splitPointValue = builder.createOrFold<AffineMinOp>(
121       builder.getIndexType(),
122       AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
123       ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
124   SmallVector<Value> splitIterationSpace =
125       llvm::to_vector(iterationSpaceShapes);
126   SmallVector<Value> originalResults = llvm::to_vector(
127       llvm::map_range(op.getOutputOperands(),
128                       [](OpOperand *opOperand) { return opOperand->get(); }));
129   SmallVector<Value> firstResults;
130   LinalgOp first =
131       createSplitPart(builder, op, originalResults, splitIterationSpace,
132                       dimension, splitPointValue, firstResults);
133 
134   // Update the iteration space to cover the remaining part of the original
135   // space, then create another instance of the `op` in that space. The size of
136   // the remaining part may become zero, but is never negative because of the
137   // adjustment above.
138   AffineExpr d0 = builder.getAffineDimExpr(0);
139   AffineExpr d1 = builder.getAffineDimExpr(1);
140   SmallVector<Value, 4> remainingSizes = applyMapToValues(
141       builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
142       {iterationSpaceShapes[dimension], splitPointValue});
143   SmallVector<Value> secondResults;
144   LinalgOp second =
145       createSplitPart(builder, op, firstResults, splitIterationSpace, dimension,
146                       remainingSizes.front(), secondResults, splitPointValue);
147 
148   // Fixup the linalg.index results in the second part.
149   SmallVector<Value> ivAdditions;
150   ivAdditions.resize(splitIterationSpace.size());
151   ivAdditions[dimension] = splitPointValue;
152   linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
153                                          ivAdditions);
154 
155   // Replace the original op with the results of the two newly created ops.
156   rewriter.replaceOp(op, secondResults);
157   return std::make_pair(first, second);
158 }
159