1 //===- Split.cpp - Structured op splitting --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/Linalg/Utils/Utils.h" 12 #include "mlir/Dialect/Utils/StaticValueUtils.h" 13 14 #include "llvm/ADT/STLExtras.h" 15 16 using namespace mlir; 17 using namespace mlir::linalg; 18 19 /// Extract the slices of `operands` supplied to the given operation `op` such 20 /// that they are sufficient to execute the op for the subset of its iteration 21 /// space defined by `splitIterationSpace`. The subset is a part of the original 22 /// iteration space split at the given `dimension`. If `offset` is provided, it 23 /// indicates the iterator value at which the dimension has been split and 24 /// requires the "high" part starting at the given offset of the operands to be 25 /// generated; otherwise, the "low" part with no offset is generated. Note that 26 /// `operands` are not necessarily the actual operands of `op`. 27 static SmallVector<Value> 28 getOperandSlices(RewriterBase &b, Location loc, LinalgOp op, 29 ValueRange splitIterationSpace, ValueRange operands, 30 unsigned dimension, Value offset = nullptr) { 31 SmallVector<Value> slices; 32 slices.reserve(op.getNumInputsAndOutputs()); 33 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 34 auto type = opOperand->get().getType().dyn_cast<ShapedType>(); 35 AffineMap indexing = op.getTiedIndexingMap(opOperand); 36 37 // If the type is not sliceable, or the slice is requested along the 38 // dimension that is not used in indexing this type, just use the entire 39 // operand. 40 if (!type || dimension >= indexing.getNumDims() || 41 !indexing.isFunctionOfDim(dimension)) { 42 slices.push_back(opOperand->get()); 43 continue; 44 } 45 46 SmallVector<OpFoldResult> sizes; 47 sizes.reserve(indexing.getNumResults()); 48 for (AffineExpr dimIndexing : indexing.getResults()) { 49 sizes.push_back(makeComposedFoldedAffineApply( 50 b, loc, dimIndexing, 51 getAsOpFoldResult(llvm::to_vector(splitIterationSpace)))); 52 } 53 SmallVector<OpFoldResult> offsets(type.getRank(), b.getIndexAttr(0)); 54 SmallVector<OpFoldResult> strides(type.getRank(), b.getIndexAttr(1)); 55 56 if (offset) { 57 offsets[dimension] = offset; 58 offsets = applyMapToValues(b, loc, indexing, offsets); 59 } 60 61 slices.push_back(createSlice(b, loc, 62 operands[opOperand->getOperandNumber()], 63 offsets, sizes, strides)); 64 } 65 66 return slices; 67 } 68 69 /// Creates a part of the given `op` split along the iteration space `dimension` 70 /// with the given `size` and an optional `offset` (default 0). Makes slices 71 /// of operands, using the input operands of the original op and the output 72 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be 73 /// a list of values representing the shape of the iteration space of the 74 /// original op and updates it to be the iteration space of the curent part. 75 /// Returns the split-out op as well as the output operand values updated with 76 /// the partial results produced by this op through `results`. 77 static LinalgOp 78 createSplitPart(RewriterBase &b, Location loc, LinalgOp op, 79 ValueRange resultOperands, 80 llvm::MutableArrayRef<Value> splitIterationSpace, 81 unsigned dimension, OpFoldResult size, 82 SmallVectorImpl<Value> &results, Value offset = nullptr) { 83 ImplicitLocOpBuilder implicit(op.getLoc(), b); 84 splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size); 85 SmallVector<Value> operands = llvm::to_vector( 86 llvm::map_range(op.getInputOperands(), 87 [](OpOperand *opOperand) { return opOperand->get(); })); 88 llvm::append_range(operands, resultOperands); 89 operands = getOperandSlices(b, loc, op, splitIterationSpace, operands, 90 dimension, offset); 91 Operation *part = 92 op.clone(b, loc, getTensorOutputTypes(op, operands), operands); 93 results = insertSlicesBack(b, loc, op, operands, part->getResults()); 94 return cast<LinalgOp>(part); 95 } 96 97 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter, 98 LinalgOp op, unsigned dimension, 99 OpFoldResult splitPoint) { 100 // Bail out on dimension overflow. 101 if (dimension >= op.getNumLoops()) 102 return std::make_pair(op, LinalgOp()); 103 104 // Compute the iteration space size as values. 105 SmallVector<Value, 4> allShapes = 106 op.createFlatListOfOperandDims(rewriter, op.getLoc()); 107 AffineMap shapesToLoops = op.getShapesToLoopsMap(); 108 SmallVector<Value, 4> iterationSpaceShapes = 109 applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes); 110 111 // Update the iteration space to have `splitPoint` as the size of `dimension` 112 // and use it to slice operands and results for a new, smaller instance of the 113 // `op`. Adjust the size if necessary to prevent overflows. Insert the partial 114 // results back. 115 OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]); 116 OpFoldResult minSplitPoint = makeComposedFoldedAffineMin( 117 rewriter, op->getLoc(), 118 AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()), 119 {splitPoint, dimSize}); 120 SmallVector<Value> splitIterationSpace = 121 llvm::to_vector(iterationSpaceShapes); 122 SmallVector<Value> originalResults = llvm::to_vector( 123 llvm::map_range(op.getOutputOperands(), 124 [](OpOperand *opOperand) { return opOperand->get(); })); 125 SmallVector<Value> firstResults; 126 LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults, 127 splitIterationSpace, dimension, 128 minSplitPoint, firstResults); 129 130 // Update the iteration space to cover the remaining part of the original 131 // space, then create another instance of the `op` in that space. The size of 132 // the remaining part may become zero, but is never negative because of the 133 // adjustment above. 134 AffineExpr d0 = rewriter.getAffineDimExpr(0); 135 AffineExpr d1 = rewriter.getAffineDimExpr(1); 136 OpFoldResult remainingSize = makeComposedFoldedAffineApply( 137 rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint}); 138 SmallVector<Value> secondResults; 139 ImplicitLocOpBuilder implicit(op.getLoc(), rewriter); 140 Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint); 141 LinalgOp second = createSplitPart( 142 rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension, 143 remainingSize, secondResults, splitPointValue); 144 145 // Fixup the linalg.index results in the second part. 146 SmallVector<Value> ivAdditions; 147 ivAdditions.resize(splitIterationSpace.size()); 148 ivAdditions[dimension] = splitPointValue; 149 linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions); 150 151 // Replace the original op with the results of the two newly created ops. 152 rewriter.replaceOp(op, secondResults); 153 return std::make_pair(first, second); 154 } 155