1 //===-------- SplitReduction.cpp - Split reduction dimesion ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements linalg transformation to break a reduction dimension 10 // between a parallel and a reduction dimension. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include <utility> 15 16 #include "mlir/Analysis/SliceAnalysis.h" 17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 18 #include "mlir/Dialect/Linalg/IR/Linalg.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/IR/PatternMatch.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 /// Return the identity numeric value associated to the give op. 28 static Optional<Attribute> getIdentity(Operation *op) { 29 // Builder only used as helper for attribute creation. 30 OpBuilder b(op->getContext()); 31 Type resultType = op->getResult(0).getType(); 32 if (auto floatType = resultType.dyn_cast<FloatType>()) { 33 const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); 34 if (isa<arith::AddFOp>(op)) 35 return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); 36 if (isa<arith::MulFOp>(op)) 37 return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); 38 if (isa<arith::MaxFOp>(op)) 39 return b.getFloatAttr(resultType, 40 llvm::APFloat::getLargest(semantic, true)); 41 if (isa<arith::MinFOp>(op)) 42 return b.getFloatAttr(resultType, 43 llvm::APFloat::getLargest(semantic, true)); 44 return llvm::None; 45 } 46 if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op)) 47 return b.getIntegerAttr(resultType, 0); 48 if (isa<arith::AndIOp>(op)) 49 return b.getIntegerAttr(resultType, -1); 50 if (isa<arith::MaxSIOp>(op)) 51 return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min()); 52 if (isa<arith::MinSIOp>(op)) 53 return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max()); 54 if (isa<arith::MulIOp>(op)) 55 return b.getIntegerAttr(resultType, 1); 56 return llvm::None; 57 } 58 59 FailureOr<LinalgOp> mlir::linalg::splitReduction( 60 PatternRewriter &b, LinalgOp op, 61 const ControlSplitReductionFn &controlSplitReductionFn, 62 const LinalgTransformationFilter &filter) { 63 if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || 64 op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || 65 !op.hasOnlyProjectedPermutations()) 66 return b.notifyMatchFailure(op, "precondition not met"); 67 std::pair<int64_t, unsigned> control = controlSplitReductionFn(op); 68 int64_t ratio = control.first; 69 unsigned insertDimIndex = control.second; 70 if (ratio <= 1) 71 return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); 72 SmallVector<unsigned> dims; 73 op.getReductionDims(dims); 74 assert(dims.size() == 1); 75 unsigned reductionDim = dims[0]; 76 Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges(); 77 if (!loopRanges) 78 return b.notifyMatchFailure(op, "Cannot analyze loops"); 79 int64_t reductionDimSize = (*loopRanges)[reductionDim]; 80 if (reductionDimSize == ShapedType::kDynamicSize || 81 reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges->size()) 82 return b.notifyMatchFailure( 83 op, "Reduction dimension not divisible by split ratio"); 84 SmallVector<Operation *, 4> combinerOps; 85 if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || 86 combinerOps.size() != 1) 87 return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); 88 Operation *reductionOp = combinerOps[0]; 89 Optional<Attribute> identity = getIdentity(reductionOp); 90 if (!identity) 91 return b.notifyMatchFailure(op, "Unknown identity value for the redution"); 92 93 Location loc = op->getLoc(); 94 SmallVector<Value> newInputs; 95 SmallVector<AffineMap> newMaps; 96 // Calculate the new shapes and indexing maps of the input operands. 97 for (OpOperand *operand : op.getInputOperands()) { 98 AffineMap map = op.getTiedIndexingMap(operand); 99 SmallVector<int64_t> newShape; 100 SmallVector<AffineExpr> exprs; 101 SmallVector<ReassociationIndices> reassociation; 102 unsigned index = 0; 103 for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) { 104 unsigned dim = map.getDimPosition(idx); 105 if (reductionDim == dim) { 106 newShape.push_back(ratio); 107 newShape.push_back(op.getShape(operand)[idx] / ratio); 108 reassociation.push_back({index++, index++}); 109 exprs.push_back(b.getAffineDimExpr(insertDimIndex)); 110 exprs.push_back( 111 b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 112 continue; 113 } 114 newShape.push_back(op.getShape(operand)[idx]); 115 exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 116 reassociation.push_back({index++}); 117 } 118 newMaps.push_back( 119 AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext())); 120 // If the shape is unchanged the input doesn't change. 121 if (newShape == op.getShape(operand)) { 122 newInputs.push_back(operand->get()); 123 continue; 124 } 125 Type newType = RankedTensorType::get( 126 newShape, 127 operand->get().getType().cast<RankedTensorType>().getElementType()); 128 Value newInput = b.create<tensor::ExpandShapeOp>( 129 loc, newType, operand->get(), reassociation); 130 newInputs.push_back(newInput); 131 } 132 // Calculate the new output map and shape, we insert the new dimension based 133 // on the index returned by `controlSplitReductionFn`. 134 SmallVector<int64_t> newOutputShape; 135 AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); 136 ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0)); 137 SmallVector<AffineExpr> outputExpr; 138 for (unsigned idx : 139 llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) { 140 if (idx == insertDimIndex) { 141 newOutputShape.push_back(ratio); 142 outputExpr.push_back(b.getAffineDimExpr(insertDimIndex)); 143 continue; 144 } 145 unsigned oldDim = idx < insertDimIndex ? idx : idx - 1; 146 newOutputShape.push_back(oldShape[oldDim]); 147 unsigned dim = oldOutputMap.getDimPosition(oldDim); 148 outputExpr.push_back( 149 b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); 150 } 151 Value initTensor = b.create<linalg::InitTensorOp>( 152 loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); 153 Value constantOp = b.create<arith::ConstantOp>(loc, *identity); 154 Value identityTensor = 155 b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor) 156 .getResult(0); 157 158 newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, 159 op.getContext())); 160 SmallVector<StringRef> newIteratorTypes; 161 for (auto &it : llvm::enumerate(op.iterator_types())) { 162 if (insertDimIndex == it.index()) 163 newIteratorTypes.push_back(getParallelIteratorTypeName()); 164 newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue()); 165 } 166 // Create the new op matching the original op with an extra parallel 167 // dimension. 168 GenericOp genericOp = b.create<GenericOp>( 169 loc, TypeRange({initTensor.getType()}), newInputs, 170 ValueRange({identityTensor}), newMaps, newIteratorTypes); 171 b.inlineRegionBefore(op->getRegion(0), genericOp.region(), 172 genericOp.region().begin()); 173 174 // Then create a new reduction that only reduce the newly added dimension from 175 // the previous op. 176 unsigned intermRank = newOutputShape.size(); 177 AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); 178 SmallVector<Value> outputOperands = op.getOutputOperands(); 179 SmallVector<StringRef> reductionIteratorTypes; 180 SmallVector<AffineExpr> exprs; 181 for (unsigned i : llvm::seq<unsigned>(0, intermRank)) { 182 if (insertDimIndex == i) { 183 reductionIteratorTypes.push_back(getReductionIteratorTypeName()); 184 } else { 185 exprs.push_back(b.getAffineDimExpr(i)); 186 reductionIteratorTypes.push_back(getParallelIteratorTypeName()); 187 } 188 } 189 AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); 190 SmallVector<AffineMap> reductionMaps = {inputMap, outputMap}; 191 192 auto reduction = b.create<GenericOp>( 193 loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), 194 outputOperands, reductionMaps, reductionIteratorTypes, 195 [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { 196 Operation *clonedReductionOp = b.clone(*reductionOp); 197 clonedReductionOp->setOperand(0, inputs[0]); 198 clonedReductionOp->setOperand(1, inputs[1]); 199 b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); 200 }); 201 b.replaceOp(op, reduction.getResults()); 202 filter.replaceLinalgTransformationFilter(b, genericOp); 203 filter.replaceLinalgTransformationFilter(b, reduction); 204 return cast<LinalgOp>(genericOp.getOperation()); 205 } 206 207 namespace { 208 209 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> { 210 /// Construct a generic pattern applied to all LinalgOp that verify `filter`. 211 LinalgSplitReduction(MLIRContext *context, 212 ControlSplitReductionFn controlSplitReductionFn, 213 LinalgTransformationFilter f, PatternBenefit benefit = 1) 214 : OpInterfaceRewritePattern<LinalgOp>(context, benefit), 215 controlSplitReductionFn(std::move(controlSplitReductionFn)), 216 filter(std::move(f)) {} 217 218 LogicalResult matchAndRewrite(LinalgOp op, 219 PatternRewriter &rewriter) const override { 220 return splitReduction(rewriter, op, controlSplitReductionFn, filter); 221 } 222 223 private: 224 ControlSplitReductionFn controlSplitReductionFn; 225 LinalgTransformationFilter filter; 226 }; 227 228 } // namespace 229 230 void linalg::populateSplitReductionPattern( 231 RewritePatternSet &patterns, 232 const ControlSplitReductionFn &controlSplitReductionFn, 233 const LinalgTransformationFilter &f) { 234 patterns.add<LinalgSplitReduction>(patterns.getContext(), 235 controlSplitReductionFn, f); 236 } 237