1 //===-------- SplitReduction.cpp - Split reduction dimesion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements linalg transformation to break a reduction dimension 10 // between a parallel and a reduction dimension. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "mlir/Analysis/SliceAnalysis.h" 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/IR/PatternMatch.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 /// Return the identity numeric value associated to the give op. 28 static Optional<Attribute> getIdentity(Operation *op) { 29 // Builder only used as helper for attribute creation. 30 OpBuilder b(op->getContext()); 31 Type resultType = op->getResult(0).getType(); 32 if (auto floatType = resultType.dyn_cast<FloatType>()) { 33 const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); 34 if (isa<arith::AddFOp>(op)) 35 return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); 36 if (isa<arith::MulFOp>(op)) 37 return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); 38 if (isa<arith::MaxFOp>(op)) 39 return b.getFloatAttr(resultType, 40 llvm::APFloat::getLargest(semantic, true)); 41 if (isa<arith::MinFOp>(op)) 42 return b.getFloatAttr(resultType, 43 llvm::APFloat::getLargest(semantic, true)); 44 return llvm::None; 45 } 46 if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op)) 47 return b.getIntegerAttr(resultType, 0); 48 if (isa<arith::AndIOp>(op)) 49 return b.getIntegerAttr(resultType, -1); 50 if (isa<arith::MaxSIOp>(op)) 51 return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min()); 52 if (isa<arith::MinSIOp>(op)) 53 return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max()); 54 if (isa<arith::MulIOp>(op)) 55 return b.getIntegerAttr(resultType, 1); 56 return llvm::None; 57 } 58 59 FailureOr<LinalgOp> mlir::linalg::splitReduction( 60 PatternRewriter &b, LinalgOp op, 61 const ControlSplitReductionFn &controlSplitReductionFn, 62 const LinalgTransformationFilter &filter) { 63 if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || 64 op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || 65 !op.hasOnlyProjectedPermutations()) 66 return b.notifyMatchFailure(op, "precondition not met"); 67 68 FailureOr<SplitReductionResult> res = 69 splitReduction(b, op, controlSplitReductionFn); 70 if (failed(res)) 71 return failure(); 72 73 filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp); 74 filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp); 75 76 return res->splitLinalgOp; 77 } 78 79 FailureOr<SplitReductionResult> mlir::linalg::splitReduction( 80 PatternRewriter &b, LinalgOp op, 81 const ControlSplitReductionFn &controlSplitReductionFn) { 82 OpBuilder::InsertionGuard guard(b); 83 b.setInsertionPoint(op); 84 85 std::pair<int64_t, unsigned> control = controlSplitReductionFn(op); 86 int64_t ratio = control.first; 87 unsigned insertDimIndex = control.second; 88 if (ratio <= 1) 89 return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); 90 91 SmallVector<unsigned> dims; 92 op.getReductionDims(dims); 93 assert(dims.size() == 1); 94 unsigned reductionDim = dims[0]; 95 SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges(); 96 int64_t reductionDimSize = loopRanges[reductionDim]; 97 if (reductionDimSize == ShapedType::kDynamicSize || 98 reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size()) 99 return b.notifyMatchFailure( 100 op, "Reduction dimension not divisible by split ratio"); 101 102 SmallVector<Operation *, 4> combinerOps; 103 if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || 104 combinerOps.size() != 1) 105 return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); 106 107 Operation *reductionOp = combinerOps[0]; 108 Optional<Attribute> identity = getIdentity(reductionOp); 109 if (!identity) 110 return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); 111 112 Location loc = op->getLoc(); 113 SmallVector<Value> newInputs; 114 SmallVector<AffineMap> newMaps; 115 // Calculate the new shapes and indexing maps of the input operands. 116 for (OpOperand *operand : op.getInputOperands()) { 117 AffineMap map = op.getTiedIndexingMap(operand); 118 SmallVector<int64_t> newShape; 119 SmallVector<AffineExpr> exprs; 120 SmallVector<ReassociationIndices> reassociation; 121 unsigned index = 0; 122 for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) { 123 unsigned dim = map.getDimPosition(idx); 124 if (reductionDim == dim) { 125 newShape.push_back(ratio); 126 newShape.push_back(op.getShape(operand)[idx] / ratio); 127 reassociation.push_back({index++, index++}); 128 exprs.push_back(b.getAffineDimExpr(insertDimIndex)); 129 exprs.push_back( 130 b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 131 continue; 132 } 133 newShape.push_back(op.getShape(operand)[idx]); 134 exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 135 reassociation.push_back({index++}); 136 } 137 newMaps.push_back( 138 AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); 139 // If the shape is unchanged the input doesn't change. 140 if (newShape == op.getShape(operand)) { 141 newInputs.push_back(operand->get()); 142 continue; 143 } 144 Type newType = RankedTensorType::get( 145 newShape, 146 operand->get().getType().cast<RankedTensorType>().getElementType()); 147 Value newInput = b.create<tensor::ExpandShapeOp>( 148 loc, newType, operand->get(), reassociation); 149 newInputs.push_back(newInput); 150 } 151 152 // Calculate the new output map and shape, we insert the new dimension based 153 // on the index returned by `controlSplitReductionFn`. 154 SmallVector<int64_t> newOutputShape; 155 AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); 156 ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0)); 157 SmallVector<AffineExpr> outputExpr; 158 for (unsigned idx : 159 llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) { 160 if (idx == insertDimIndex) { 161 newOutputShape.push_back(ratio); 162 outputExpr.push_back(b.getAffineDimExpr(insertDimIndex)); 163 continue; 164 } 165 unsigned oldDim = idx < insertDimIndex ? idx : idx - 1; 166 newOutputShape.push_back(oldShape[oldDim]); 167 unsigned dim = oldOutputMap.getDimPosition(oldDim); 168 outputExpr.push_back( 169 b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 170 } 171 Value initTensor = b.create<linalg::InitTensorOp>( 172 loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); 173 Value constantOp = b.create<arith::ConstantOp>(loc, *identity); 174 Value identityTensor = 175 b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor) 176 .getResult(0); 177 178 newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, 179 op.getContext())); 180 SmallVector<StringRef> newIteratorTypes; 181 for (auto &it : llvm::enumerate(op.iterator_types())) { 182 if (insertDimIndex == it.index()) 183 newIteratorTypes.push_back(getParallelIteratorTypeName()); 184 newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue()); 185 } 186 // Create the new op matching the original op with an extra parallel 187 // dimension. 188 GenericOp genericOp = b.create<GenericOp>( 189 loc, TypeRange({initTensor.getType()}), newInputs, 190 ValueRange({identityTensor}), newMaps, newIteratorTypes); 191 b.inlineRegionBefore(op->getRegion(0), genericOp.region(), 192 genericOp.region().begin()); 193 194 // Then create a new reduction that only reduce the newly added dimension 195 // from the previous op. 196 unsigned intermRank = newOutputShape.size(); 197 AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); 198 SmallVector<Value> outputOperands = op.getOutputOperands(); 199 SmallVector<StringRef> reductionIteratorTypes; 200 SmallVector<AffineExpr> exprs; 201 for (unsigned i : llvm::seq<unsigned>(0, intermRank)) { 202 if (insertDimIndex == i) { 203 reductionIteratorTypes.push_back(getReductionIteratorTypeName()); 204 } else { 205 exprs.push_back(b.getAffineDimExpr(i)); 206 reductionIteratorTypes.push_back(getParallelIteratorTypeName()); 207 } 208 } 209 AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); 210 SmallVector<AffineMap> reductionMaps = {inputMap, outputMap}; 211 212 auto reduction = b.create<GenericOp>( 213 loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), 214 outputOperands, reductionMaps, reductionIteratorTypes, 215 [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { 216 Operation *clonedReductionOp = b.clone(*reductionOp); 217 clonedReductionOp->setOperand(0, inputs[0]); 218 clonedReductionOp->setOperand(1, inputs[1]); 219 b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); 220 }); 221 b.replaceOp(op, reduction.getResults()); 222 223 return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(), 224 cast<LinalgOp>(genericOp.getOperation()), 225 reduction}; 226 } 227 228 namespace { 229 230 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> { 231 /// Construct a generic pattern applied to all LinalgOp that verify `filter`. 232 LinalgSplitReduction(MLIRContext *context, 233 ControlSplitReductionFn controlSplitReductionFn, 234 LinalgTransformationFilter f, PatternBenefit benefit = 1) 235 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 236 controlSplitReductionFn(std::move(controlSplitReductionFn)), 237 filter(std::move(f)) {} 238 239 LogicalResult matchAndRewrite(LinalgOp op, 240 PatternRewriter &rewriter) const override { 241 return splitReduction(rewriter, op, controlSplitReductionFn, filter); 242 } 243 244 private: 245 ControlSplitReductionFn controlSplitReductionFn; 246 LinalgTransformationFilter filter; 247 }; 248 249 } // namespace 250 251 void linalg::populateSplitReductionPattern( 252 RewritePatternSet &patterns, 253 const ControlSplitReductionFn &controlSplitReductionFn, 254 const LinalgTransformationFilter &f) { 255 patterns.add<LinalgSplitReduction>(patterns.getContext(), 256 controlSplitReductionFn, f); 257 } 258