1 //===- Split.cpp - Structured op splitting --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
11 #include "mlir/Dialect/Linalg/Utils/Utils.h"
12 
13 #include "llvm/ADT/STLExtras.h"
14 
15 using namespace mlir;
16 using namespace mlir::linalg;
17 
18 /// Extract the slices of `operands` supplied to the given operation `op` such
19 /// that they are sufficient to execute the op for the subset of its iteration
20 /// space defined by `splitIterationSpace`. The subset is a part of the original
21 /// iteration space split at the given `dimension`. If `offset` is provided, it
22 /// indicates the iterator value at which the dimension has been split and
23 /// requires the "high" part starting at the given offset of the operands to be
24 /// generated; otherwise, the "low" part with no offset is generated. Note that
25 /// `operands` are not necessarily the actual operands of `op`.
26 static SmallVector<Value>
27 getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
28                  ValueRange splitIterationSpace, ValueRange operands,
29                  unsigned dimension, Value offset = nullptr) {
30   SmallVector<Value> slices;
31   slices.reserve(op.getNumInputsAndOutputs());
32   for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
33     auto type = opOperand->get().getType().dyn_cast<ShapedType>();
34     AffineMap indexing = op.getTiedIndexingMap(opOperand);
35 
36     // If the type is not sliceable, or the slice is requested along the
37     // dimension that is not used in indexing this type, just use the entire
38     // operand.
39     if (!type || dimension >= indexing.getNumDims() ||
40         !indexing.isFunctionOfDim(dimension)) {
41       slices.push_back(opOperand->get());
42       continue;
43     }
44 
45     SmallVector<Value, 4> sizes =
46         applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
47     SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
48     SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
49 
50     if (offset) {
51       offsets[dimension] = offset;
52       IRRewriter rewriter(builder);
53       offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets);
54     }
55 
56     slices.push_back(createSlice(builder, op.getLoc(),
57                                  operands[opOperand->getOperandNumber()],
58                                  offsets, getAsOpFoldResult(sizes), strides));
59   }
60 
61   return slices;
62 }
63 
64 /// Creates a part of the given `op` split along the iteration space `dimension`
65 /// with the given `size` and an optional `offset` (default 0). Makes slices
66 /// of operands, using the input operands of the original op and the output
67 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
68 /// a list of values representing the shape of the iteration space of the
69 /// original op and updates it to be the iteration space of the curent part.
70 /// Returns the split-out op as well as the output operand values updated with
71 /// the partial results produced by this op through `results`.
72 static LinalgOp createSplitPart(
73     ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
74     llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
75     Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
76   splitIterationSpace[dimension] = size;
77   SmallVector<Value> operands = llvm::to_vector(
78       llvm::map_range(op.getInputOperands(),
79                       [](OpOperand *opOperand) { return opOperand->get(); }));
80   llvm::append_range(operands, resultOperands);
81   operands = getOperandSlices(builder, op, splitIterationSpace, operands,
82                               dimension, offset);
83   Operation *part = op.clone(builder, op.getLoc(),
84                              getTensorOutputTypes(op, operands), operands);
85   results = insertSlicesBack(builder, builder.getLoc(), op, operands,
86                              part->getResults());
87   return cast<LinalgOp>(part);
88 }
89 
90 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
91                                               LinalgOp op, unsigned dimension,
92                                               OpFoldResult splitPoint) {
93   // Bail out on dimension overflow.
94   if (dimension >= op.getNumLoops())
95     return std::make_pair(op, LinalgOp());
96 
97   // Compute the iteration space size as values.
98   ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
99   SmallVector<Value, 4> allShapes =
100       op.createFlatListOfOperandDims(builder, op.getLoc());
101   AffineMap shapesToLoops = op.getShapesToLoopsMap();
102   SmallVector<Value, 4> iterationSpaceShapes =
103       applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
104 
105   // Update the iteration space to have `splitPoint` as the size of `dimension`
106   // and use it to slice operands and results for a new, smaller instance of the
107   // `op`. Adjust the size if necessary to prevent overflows. Insert the partial
108   // results back.
109   Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
110   splitPointValue = builder.createOrFold<AffineMinOp>(
111       builder.getIndexType(),
112       AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
113       ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
114   SmallVector<Value> splitIterationSpace =
115       llvm::to_vector(iterationSpaceShapes);
116   SmallVector<Value> originalResults = llvm::to_vector(
117       llvm::map_range(op.getOutputOperands(),
118                       [](OpOperand *opOperand) { return opOperand->get(); }));
119   SmallVector<Value> firstResults;
120   LinalgOp first =
121       createSplitPart(builder, op, originalResults, splitIterationSpace,
122                       dimension, splitPointValue, firstResults);
123 
124   // Update the iteration space to cover the remaining part of the original
125   // space, then create another instance of the `op` in that space. The size of
126   // the remaining part may become zero, but is never negative because of the
127   // adjustment above.
128   AffineExpr d0 = builder.getAffineDimExpr(0);
129   AffineExpr d1 = builder.getAffineDimExpr(1);
130   SmallVector<Value, 4> remainingSizes = applyMapToValues(
131       builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
132       {iterationSpaceShapes[dimension], splitPointValue});
133   SmallVector<Value> secondResults;
134   LinalgOp second =
135       createSplitPart(builder, op, firstResults, splitIterationSpace, dimension,
136                       remainingSizes.front(), secondResults, splitPointValue);
137 
138   // Fixup the linalg.index results in the second part.
139   SmallVector<Value> ivAdditions;
140   ivAdditions.resize(splitIterationSpace.size());
141   ivAdditions[dimension] = splitPointValue;
142   linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
143                                          ivAdditions);
144 
145   // Replace the original op with the results of the two newly created ops.
146   rewriter.replaceOp(op, secondResults);
147   return std::make_pair(first, second);
148 }
149