1 //===- Split.cpp - Structured op splitting --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/Linalg/Utils/Utils.h" 12 13 #include "llvm/ADT/STLExtras.h" 14 15 using namespace mlir; 16 using namespace mlir::linalg; 17 18 /// Extract the slices of `operands` supplied to the given operation `op` such 19 /// that they are sufficient to execute the op for the subset of its iteration 20 /// space defined by `splitIterationSpace`. The subset is a part of the original 21 /// iteration space split at the given `dimension`. If `offset` is provided, it 22 /// indicates the iterator value at which the dimension has been split and 23 /// requires the "high" part starting at the given offset of the operands to be 24 /// generated; otherwise, the "low" part with no offset is generated. Note that 25 /// `operands` are not necessarily the actual operands of `op`. 26 static SmallVector<Value> 27 getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op, 28 ValueRange splitIterationSpace, ValueRange operands, 29 unsigned dimension, Value offset = nullptr) { 30 SmallVector<Value> slices; 31 slices.reserve(op.getNumInputsAndOutputs()); 32 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 33 auto type = opOperand->get().getType().dyn_cast<ShapedType>(); 34 AffineMap indexing = op.getTiedIndexingMap(opOperand); 35 36 // If the type is not sliceable, or the slice is requested along the 37 // dimension that is not used in indexing this type, just use the entire 38 // operand. 39 if (!type || dimension >= indexing.getNumDims() || 40 !indexing.isFunctionOfDim(dimension)) { 41 slices.push_back(opOperand->get()); 42 continue; 43 } 44 45 SmallVector<Value, 4> sizes = 46 applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace); 47 SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0)); 48 SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1)); 49 50 if (offset) { 51 offsets[dimension] = offset; 52 IRRewriter rewriter(builder); 53 offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets); 54 } 55 56 slices.push_back(createSlice(builder, op.getLoc(), 57 operands[opOperand->getOperandNumber()], 58 offsets, getAsOpFoldResult(sizes), strides)); 59 } 60 61 return slices; 62 } 63 64 /// Creates a part of the given `op` split along the iteration space `dimension` 65 /// with the given `size` and an optional `offset` (default 0). Makes slices 66 /// of operands, using the input operands of the original op and the output 67 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be 68 /// a list of values representing the shape of the iteration space of the 69 /// original op and updates it to be the iteration space of the curent part. 70 /// Returns the split-out op as well as the output operand values updated with 71 /// the partial results produced by this op through `results`. 72 static LinalgOp createSplitPart( 73 ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands, 74 llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension, 75 Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) { 76 splitIterationSpace[dimension] = size; 77 SmallVector<Value> operands = llvm::to_vector( 78 llvm::map_range(op.getInputOperands(), 79 [](OpOperand *opOperand) { return opOperand->get(); })); 80 llvm::append_range(operands, resultOperands); 81 operands = getOperandSlices(builder, op, splitIterationSpace, operands, 82 dimension, offset); 83 Operation *part = op.clone(builder, op.getLoc(), 84 getTensorOutputTypes(op, operands), operands); 85 results = insertSlicesBack(builder, builder.getLoc(), op, operands, 86 part->getResults()); 87 return cast<LinalgOp>(part); 88 } 89 90 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter, 91 LinalgOp op, unsigned dimension, 92 OpFoldResult splitPoint) { 93 // Bail out on dimension overflow. 94 if (dimension >= op.getNumLoops()) 95 return std::make_pair(op, LinalgOp()); 96 97 // Compute the iteration space size as values. 98 ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 99 SmallVector<Value, 4> allShapes = 100 op.createFlatListOfOperandDims(builder, op.getLoc()); 101 AffineMap shapesToLoops = op.getShapesToLoopsMap(); 102 SmallVector<Value, 4> iterationSpaceShapes = 103 applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes); 104 105 // Update the iteration space to have `splitPoint` as the size of `dimension` 106 // and use it to slice operands and results for a new, smaller instance of the 107 // `op`. Adjust the size if necessary to prevent overflows. Insert the partial 108 // results back. 109 Value splitPointValue = materializeOpFoldResult(builder, splitPoint); 110 splitPointValue = builder.createOrFold<AffineMinOp>( 111 builder.getIndexType(), 112 AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()), 113 ValueRange({splitPointValue, iterationSpaceShapes[dimension]})); 114 SmallVector<Value> splitIterationSpace = 115 llvm::to_vector(iterationSpaceShapes); 116 SmallVector<Value> originalResults = llvm::to_vector( 117 llvm::map_range(op.getOutputOperands(), 118 [](OpOperand *opOperand) { return opOperand->get(); })); 119 SmallVector<Value> firstResults; 120 LinalgOp first = 121 createSplitPart(builder, op, originalResults, splitIterationSpace, 122 dimension, splitPointValue, firstResults); 123 124 // Update the iteration space to cover the remaining part of the original 125 // space, then create another instance of the `op` in that space. The size of 126 // the remaining part may become zero, but is never negative because of the 127 // adjustment above. 128 AffineExpr d0 = builder.getAffineDimExpr(0); 129 AffineExpr d1 = builder.getAffineDimExpr(1); 130 SmallVector<Value, 4> remainingSizes = applyMapToValues( 131 builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(), 132 {iterationSpaceShapes[dimension], splitPointValue}); 133 SmallVector<Value> secondResults; 134 LinalgOp second = 135 createSplitPart(builder, op, firstResults, splitIterationSpace, dimension, 136 remainingSizes.front(), secondResults, splitPointValue); 137 138 // Fixup the linalg.index results in the second part. 139 SmallVector<Value> ivAdditions; 140 ivAdditions.resize(splitIterationSpace.size()); 141 ivAdditions[dimension] = splitPointValue; 142 linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second), 143 ivAdditions); 144 145 // Replace the original op with the results of the two newly created ops. 146 rewriter.replaceOp(op, secondResults); 147 return std::make_pair(first, second); 148 } 149