//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg transformation that breaks a reduction
// dimension into a parallel dimension and a smaller reduction dimension.
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

/// Return the identity numeric value associated with the given combiner op,
/// or llvm::None if the op is not recognized.
static Optional<Attribute> getIdentity(Operation *op) {
  // Builder only used as a helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    // The identity of max is the smallest representable value and vice versa.
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return llvm::None;
  }
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  if (isa<arith::AndIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  // Use the result type's own bit width so that narrow integer types get the
  // correct identity (std::numeric_limits<int64_t> would be truncated).
  if (isa<arith::MaxSIOp>(op)) {
    unsigned width = resultType.getIntOrFloatBitWidth();
    return b.getIntegerAttr(
        resultType, llvm::APInt::getSignedMinValue(width).getSExtValue());
  }
  if (isa<arith::MinSIOp>(op)) {
    unsigned width = resultType.getIntOrFloatBitWidth();
    return b.getIntegerAttr(
        resultType, llvm::APInt::getSignedMaxValue(width).getSExtValue());
  }
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return llvm::None;
}

FailureOr<LinalgOp> mlir::linalg::splitReduction(
    PatternRewriter &b, LinalgOp op,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");
  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];
  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
  int64_t reductionDimSize = loopRanges[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size())
    return b.notifyMatchFailure(
        op, "reduction dimension not divisible by split ratio");
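
  // Note: `controlSplitReductionFn` fully drives the split. For reference, a
  // minimal control function (a sketch only; real heuristics would inspect
  // `op`) could return a fixed ratio of 4 and insert the new dimension at
  // index 0:
  //   ControlSplitReductionFn control = [](LinalgOp op) {
  //     return std::pair<int64_t, unsigned>(4, 0);
  //   };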
  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "cannot match the reduction pattern");
  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "unknown identity value for the reduction");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // Split the reduction dimension in two: a parallel dimension of size
        // `ratio` and a reduction dimension of the remaining size.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index++, index++});
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged, the input doesn't need to be reshaped.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }
  // Calculate the new output map and shape; the new dimension is inserted at
  // the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
  }
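  // Create the output tensor of the first op and fill it with the identity
  // value, so that the partial reductions start from a neutral element.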
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(loc, constantOp, initTensor).getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // of the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  filter.replaceLinalgTransformationFilter(b, genericOp);
  filter.replaceLinalgTransformationFilter(b, reduction);
  return cast<LinalgOp>(genericOp.getOperation());
}
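
// For illustration, with a control function returning {ratio = 4,
// insertDimIndex = 0}, a 1-D float sum reduction is rewritten roughly as
// follows (a hand-written sketch of the expected IR, not actual pass output;
// SSA names are made up):
//
//   %in  = ... : tensor<32xf32>
//   %exp = tensor.expand_shape %in [[0, 1]]
//            : tensor<32xf32> into tensor<4x8xf32>
//   %cst = arith.constant 0.0 : f32
//   %it  = linalg.init_tensor [4] : tensor<4xf32>
//   %id  = linalg.fill ins(%cst : f32) outs(%it : tensor<4xf32>)
//   // First op: parallel over the new size-4 dimension, reducing size 8.
//   %par = linalg.generic {iterator_types = ["parallel", "reduction"], ...}
//            ins(%exp : tensor<4x8xf32>) outs(%id : tensor<4xf32>)
//   // Second op: reduce the 4 partial results into the original output.
//   %res = linalg.generic {iterator_types = ["reduction"], ...}
//            ins(%par : tensor<4xf32>) outs(%out : tensor<f32>)
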
namespace {

struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOps that match `filter`.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  LinalgTransformationFilter filter;
};

} // namespace

void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     controlSplitReductionFn, f);
}
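
// A minimal usage sketch (hypothetical driver code; `funcOp` and the fixed
// {4, 0} control are illustrative assumptions, not part of this file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateSplitReductionPattern(
//       patterns,
//       [](LinalgOp op) { return std::pair<int64_t, unsigned>(4, 0); },
//       LinalgTransformationFilter());
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));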