1 //===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===// 2 // 3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information. 5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This file implements target-independent rewrites of MultiDimReductionOp. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/ImplicitLocOpBuilder.h" 18 #include "mlir/IR/TypeUtilities.h" 19 20 #define DEBUG_TYPE "vector-multi-reduction" 21 22 using namespace mlir; 23 24 /// This file implements the following transformations as composable atomic 25 /// patterns. 26 27 /// Converts vector.multi_reduction into inner-most/outer-most reduction form 28 /// by using vector.transpose 29 class InnerOuterDimReductionConversion 30 : public OpRewritePattern<vector::MultiDimReductionOp> { 31 public: 32 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 33 34 explicit InnerOuterDimReductionConversion( 35 MLIRContext *context, vector::VectorMultiReductionLowering options) 36 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context), 37 useInnerDimsForReduction( 38 options == vector::VectorMultiReductionLowering::InnerReduction) {} 39 40 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 41 PatternRewriter &rewriter) const override { 42 auto src = multiReductionOp.getSource(); 43 auto loc = multiReductionOp.getLoc(); 44 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 45 46 // Separate reduction and parallel dims 47 auto reductionDimsRange = 48 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>(); 49 auto reductionDims = llvm::to_vector<4>(llvm::map_range( 50 reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); 51 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(), 52 reductionDims.end()); 53 int64_t reductionSize = reductionDims.size(); 54 SmallVector<int64_t, 4> parallelDims; 55 for (int64_t i = 0; i < srcRank; ++i) 56 if (!reductionDimsSet.contains(i)) 57 parallelDims.push_back(i); 58 59 // Add transpose only if inner-most/outer-most dimensions are not parallel 60 // and there are parallel dims. 61 if (parallelDims.empty()) 62 return failure(); 63 if (useInnerDimsForReduction && 64 (parallelDims == 65 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 66 return failure(); 67 68 if (!useInnerDimsForReduction && 69 (parallelDims != 70 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 71 return failure(); 72 73 SmallVector<int64_t, 4> indices; 74 if (useInnerDimsForReduction) { 75 indices.append(parallelDims.begin(), parallelDims.end()); 76 indices.append(reductionDims.begin(), reductionDims.end()); 77 } else { 78 indices.append(reductionDims.begin(), reductionDims.end()); 79 indices.append(parallelDims.begin(), parallelDims.end()); 80 } 81 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices); 82 SmallVector<bool> reductionMask(srcRank, false); 83 for (int i = 0; i < reductionSize; ++i) { 84 if (useInnerDimsForReduction) 85 reductionMask[srcRank - i - 1] = true; 86 else 87 reductionMask[i] = true; 88 } 89 rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>( 90 multiReductionOp, transposeOp.getResult(), reductionMask, 91 multiReductionOp.getKind()); 92 return success(); 93 } 94 95 private: 96 const bool useInnerDimsForReduction; 97 }; 98 99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction 100 /// dimensions are either inner most or outer most. 101 class ReduceMultiDimReductionRank 102 : public OpRewritePattern<vector::MultiDimReductionOp> { 103 public: 104 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 105 106 explicit ReduceMultiDimReductionRank( 107 MLIRContext *context, vector::VectorMultiReductionLowering options) 108 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context), 109 useInnerDimsForReduction( 110 options == vector::VectorMultiReductionLowering::InnerReduction) {} 111 112 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 113 PatternRewriter &rewriter) const override { 114 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 115 auto srcShape = multiReductionOp.getSourceVectorType().getShape(); 116 auto loc = multiReductionOp.getLoc(); 117 118 // If rank less than 2, nothing to do. 119 if (srcRank < 2) 120 return failure(); 121 122 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail. 123 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask(); 124 if (srcRank == 2 && reductionMask.front() != reductionMask.back()) 125 return failure(); 126 127 // 1. Separate reduction and parallel dims. 128 SmallVector<int64_t, 4> parallelDims, parallelShapes; 129 SmallVector<int64_t, 4> reductionDims, reductionShapes; 130 for (const auto &it : llvm::enumerate(reductionMask)) { 131 int64_t i = it.index(); 132 bool isReduction = it.value(); 133 if (isReduction) { 134 reductionDims.push_back(i); 135 reductionShapes.push_back(srcShape[i]); 136 } else { 137 parallelDims.push_back(i); 138 parallelShapes.push_back(srcShape[i]); 139 } 140 } 141 142 // 2. Compute flattened parallel and reduction sizes. 143 int flattenedParallelDim = 0; 144 int flattenedReductionDim = 0; 145 if (!parallelShapes.empty()) { 146 flattenedParallelDim = 1; 147 for (auto d : parallelShapes) 148 flattenedParallelDim *= d; 149 } 150 if (!reductionShapes.empty()) { 151 flattenedReductionDim = 1; 152 for (auto d : reductionShapes) 153 flattenedReductionDim *= d; 154 } 155 // We must at least have some parallel or some reduction. 156 assert((flattenedParallelDim || flattenedReductionDim) && 157 "expected at least one parallel or reduction dim"); 158 159 // 3. Fail if reduction/parallel dims are not contiguous. 160 // Check parallelDims are exactly [0 .. size). 161 int64_t counter = 0; 162 if (useInnerDimsForReduction && 163 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) 164 return failure(); 165 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size). 166 counter = reductionDims.size(); 167 if (!useInnerDimsForReduction && 168 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) 169 return failure(); 170 171 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into 172 // a single parallel (resp. reduction) dim. 173 SmallVector<bool, 2> mask; 174 SmallVector<int64_t, 2> vectorShape; 175 if (flattenedParallelDim) { 176 mask.push_back(false); 177 vectorShape.push_back(flattenedParallelDim); 178 } 179 if (flattenedReductionDim) { 180 mask.push_back(true); 181 vectorShape.push_back(flattenedReductionDim); 182 } 183 if (!useInnerDimsForReduction && vectorShape.size() == 2) { 184 std::swap(mask.front(), mask.back()); 185 std::swap(vectorShape.front(), vectorShape.back()); 186 } 187 auto castedType = VectorType::get( 188 vectorShape, multiReductionOp.getSourceVectorType().getElementType()); 189 Value cast = rewriter.create<vector::ShapeCastOp>( 190 loc, castedType, multiReductionOp.getSource()); 191 192 // 5. Creates the flattened form of vector.multi_reduction with inner/outer 193 // most dim as reduction. 194 auto newOp = rewriter.create<vector::MultiDimReductionOp>( 195 loc, cast, mask, multiReductionOp.getKind()); 196 197 // 6. If there are no parallel shapes, the result is a scalar. 198 // TODO: support 0-d vectors when available. 199 if (parallelShapes.empty()) { 200 rewriter.replaceOp(multiReductionOp, newOp.getDest()); 201 return success(); 202 } 203 204 // 7. Creates shape cast for the output n-D -> 2-D 205 VectorType outputCastedType = VectorType::get( 206 parallelShapes, 207 multiReductionOp.getSourceVectorType().getElementType()); 208 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( 209 multiReductionOp, outputCastedType, newOp.getDest()); 210 return success(); 211 } 212 213 private: 214 const bool useInnerDimsForReduction; 215 }; 216 217 /// Unrolls vector.multi_reduction with outermost reductions 218 /// and combines results 219 struct TwoDimMultiReductionToElementWise 220 : public OpRewritePattern<vector::MultiDimReductionOp> { 221 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 222 223 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 224 PatternRewriter &rewriter) const override { 225 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 226 // Rank-2 ["parallel", "reduce"] or bail. 227 if (srcRank != 2) 228 return failure(); 229 230 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0)) 231 return failure(); 232 233 auto loc = multiReductionOp.getLoc(); 234 ArrayRef<int64_t> srcShape = 235 multiReductionOp.getSourceVectorType().getShape(); 236 237 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); 238 if (!elementType.isIntOrIndexOrFloat()) 239 return failure(); 240 241 Value result = 242 rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0) 243 .getResult(); 244 for (int64_t i = 1; i < srcShape[0]; i++) { 245 auto operand = rewriter.create<vector::ExtractOp>( 246 loc, multiReductionOp.getSource(), i); 247 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), 248 operand, result); 249 } 250 251 rewriter.replaceOp(multiReductionOp, result); 252 return success(); 253 } 254 }; 255 256 /// Converts 2d vector.multi_reduction with inner most reduction dimension into 257 /// a sequence of vector.reduction ops. 258 struct TwoDimMultiReductionToReduction 259 : public OpRewritePattern<vector::MultiDimReductionOp> { 260 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 261 262 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 263 PatternRewriter &rewriter) const override { 264 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 265 if (srcRank != 2) 266 return failure(); 267 268 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) 269 return failure(); 270 271 auto loc = multiReductionOp.getLoc(); 272 Value result = rewriter.create<arith::ConstantOp>( 273 loc, multiReductionOp.getDestType(), 274 rewriter.getZeroAttr(multiReductionOp.getDestType())); 275 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; 276 277 for (int i = 0; i < outerDim; ++i) { 278 auto v = rewriter.create<vector::ExtractOp>( 279 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i}); 280 auto reducedValue = rewriter.create<vector::ReductionOp>( 281 loc, multiReductionOp.getKind(), v); 282 result = rewriter.create<vector::InsertElementOp>( 283 loc, reducedValue, result, 284 rewriter.create<arith::ConstantIndexOp>(loc, i)); 285 } 286 rewriter.replaceOp(multiReductionOp, result); 287 return success(); 288 } 289 }; 290 291 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d 292 /// form with both a single parallel and reduction dimension. 293 /// This is achieved with a simple vector.shape_cast that inserts a leading 1. 294 /// The case with a single parallel dimension is a noop and folds away 295 /// separately. 296 struct OneDimMultiReductionToTwoDim 297 : public OpRewritePattern<vector::MultiDimReductionOp> { 298 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 299 300 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 301 PatternRewriter &rewriter) const override { 302 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 303 // Rank-1 or bail. 304 if (srcRank != 1) 305 return failure(); 306 307 auto loc = multiReductionOp.getLoc(); 308 auto srcVectorType = multiReductionOp.getSourceVectorType(); 309 auto srcShape = srcVectorType.getShape(); 310 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()}, 311 srcVectorType.getElementType()); 312 assert(!multiReductionOp.getDestType().isa<VectorType>() && 313 "multi_reduction with a single dimension expects a scalar result"); 314 315 // If the unique dim is reduced and we insert a parallel in front, we need a 316 // {false, true} mask. 317 SmallVector<bool, 2> mask{false, true}; 318 319 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) 320 Value cast = rewriter.create<vector::ShapeCastOp>( 321 loc, castedType, multiReductionOp.getSource()); 322 Value reduced = rewriter.create<vector::MultiDimReductionOp>( 323 loc, cast, mask, multiReductionOp.getKind()); 324 rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced, 325 ArrayRef<int64_t>{0}); 326 return success(); 327 } 328 }; 329 330 void mlir::vector::populateVectorMultiReductionLoweringPatterns( 331 RewritePatternSet &patterns, VectorMultiReductionLowering options) { 332 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( 333 patterns.getContext(), options); 334 patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext()); 335 if (options == VectorMultiReductionLowering ::InnerReduction) 336 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext()); 337 else 338 patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext()); 339 } 340