1 //===- Split.cpp - Structured op splitting --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 11 #include "mlir/Dialect/Linalg/Utils/Utils.h" 12 13 #include "llvm/ADT/STLExtras.h" 14 15 using namespace mlir; 16 using namespace mlir::linalg; 17 18 /// Turns an OpFoldResult into a value, creating an index-typed constant if 19 /// necessary. 20 static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder, 21 OpFoldResult opFoldResult) { 22 if (opFoldResult.is<Value>()) 23 return opFoldResult.get<Value>(); 24 auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>(); 25 return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue()); 26 } 27 28 /// Extract the slices of `operands` supplied to the given operation `op` such 29 /// that they are sufficient to execute the op for the subset of its iteration 30 /// space defined by `splitIterationSpace`. The subset is a part of the original 31 /// iteration space split at the given `dimension`. If `offset` is provided, it 32 /// indicates the iterator value at which the dimension has been split and 33 /// requires the "high" part starting at the given offset of the operands to be 34 /// generated; otherwise, the "low" part with no offset is generated. Note that 35 /// `operands` are not necessarily the actual operands of `op`. 36 static SmallVector<Value> 37 getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op, 38 ValueRange splitIterationSpace, ValueRange operands, 39 unsigned dimension, Value offset = nullptr) { 40 SmallVector<Value> slices; 41 slices.reserve(op.getNumInputsAndOutputs()); 42 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 43 auto type = opOperand->get().getType().dyn_cast<ShapedType>(); 44 AffineMap indexing = op.getTiedIndexingMap(opOperand); 45 46 // If the type is not sliceable, or the slice is requested along the 47 // dimension that is not used in indexing this type, just use the entire 48 // operand. 49 if (!type || dimension >= indexing.getNumDims() || 50 !indexing.isFunctionOfDim(dimension)) { 51 slices.push_back(opOperand->get()); 52 continue; 53 } 54 55 SmallVector<Value, 4> sizes = 56 applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace); 57 SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0)); 58 SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1)); 59 60 if (offset) { 61 offsets[dimension] = offset; 62 IRRewriter rewriter(builder); 63 offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets); 64 } 65 66 slices.push_back(createSlice(builder, op.getLoc(), 67 operands[opOperand->getOperandNumber()], 68 offsets, getAsOpFoldResult(sizes), strides)); 69 } 70 71 return slices; 72 } 73 74 /// Creates a part of the given `op` split along the iteration space `dimension` 75 /// with the given `size` and an optional `offset` (default 0). Makes slices 76 /// of operands, using the input operands of the original op and the output 77 /// operands provided as `resultOperands`. Expects `splitIterationSpace` to be 78 /// a list of values representing the shape of the iteration space of the 79 /// original op and updates it to be the iteration space of the curent part. 80 /// Returns the split-out op as well as the output operand values updated with 81 /// the partial results produced by this op through `results`. 82 static LinalgOp createSplitPart( 83 ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands, 84 llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension, 85 Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) { 86 splitIterationSpace[dimension] = size; 87 SmallVector<Value> operands = llvm::to_vector( 88 llvm::map_range(op.getInputOperands(), 89 [](OpOperand *opOperand) { return opOperand->get(); })); 90 llvm::append_range(operands, resultOperands); 91 operands = getOperandSlices(builder, op, splitIterationSpace, operands, 92 dimension, offset); 93 Operation *part = op.clone(builder, op.getLoc(), 94 getTensorOutputTypes(op, operands), operands); 95 results = insertSlicesBack(builder, builder.getLoc(), op, operands, 96 part->getResults()); 97 return cast<LinalgOp>(part); 98 } 99 100 std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter, 101 LinalgOp op, unsigned dimension, 102 OpFoldResult splitPoint) { 103 // Bail out on dimension overflow. 104 if (dimension >= op.getNumLoops()) 105 return std::make_pair(op, LinalgOp()); 106 107 // Compute the iteration space size as values. 108 ImplicitLocOpBuilder builder(op.getLoc(), rewriter); 109 SmallVector<Value, 4> allShapes = 110 op.createFlatListOfOperandDims(builder, op.getLoc()); 111 AffineMap shapesToLoops = op.getShapesToLoopsMap(); 112 SmallVector<Value, 4> iterationSpaceShapes = 113 applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes); 114 115 // Update the iteration space to have `splitPoint` as the size of `dimension` 116 // and use it to slice operands and results for a new, smaller instance of the 117 // `op`. Adjust the size if necessary to prevent overflows. Insert the partial 118 // results back. 119 Value splitPointValue = materializeOpFoldResult(builder, splitPoint); 120 splitPointValue = builder.createOrFold<AffineMinOp>( 121 builder.getIndexType(), 122 AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()), 123 ValueRange({splitPointValue, iterationSpaceShapes[dimension]})); 124 SmallVector<Value> splitIterationSpace = 125 llvm::to_vector(iterationSpaceShapes); 126 SmallVector<Value> originalResults = llvm::to_vector( 127 llvm::map_range(op.getOutputOperands(), 128 [](OpOperand *opOperand) { return opOperand->get(); })); 129 SmallVector<Value> firstResults; 130 LinalgOp first = 131 createSplitPart(builder, op, originalResults, splitIterationSpace, 132 dimension, splitPointValue, firstResults); 133 134 // Update the iteration space to cover the remaining part of the original 135 // space, then create another instance of the `op` in that space. The size of 136 // the remaining part may become zero, but is never negative because of the 137 // adjustment above. 138 AffineExpr d0 = builder.getAffineDimExpr(0); 139 AffineExpr d1 = builder.getAffineDimExpr(1); 140 SmallVector<Value, 4> remainingSizes = applyMapToValues( 141 builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(), 142 {iterationSpaceShapes[dimension], splitPointValue}); 143 SmallVector<Value> secondResults; 144 LinalgOp second = 145 createSplitPart(builder, op, firstResults, splitIterationSpace, dimension, 146 remainingSizes.front(), secondResults, splitPointValue); 147 148 // Fixup the linalg.index results in the second part. 149 SmallVector<Value> ivAdditions; 150 ivAdditions.resize(splitIterationSpace.size()); 151 ivAdditions[dimension] = splitPointValue; 152 linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second), 153 ivAdditions); 154 155 // Replace the original op with the results of the two newly created ops. 156 rewriter.replaceOp(op, secondResults); 157 return std::make_pair(first, second); 158 } 159