1 //===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===// 2 // 3 /// Part of the LLVM Project, under the Apache License v2.0 with LLVM 4 /// Exceptions. See https://llvm.org/LICENSE.txt for license information. 5 /// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This file implements target-independent rewrites of MultiDimReductionOp. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 15 #include "mlir/IR/Builders.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/IR/TypeUtilities.h" 18 19 #define DEBUG_TYPE "vector-multi-reduction" 20 21 using namespace mlir; 22 23 /// This file implements the following transformations as composable atomic 24 /// patterns. 25 26 /// Converts vector.multi_reduction into inner-most/outer-most reduction form 27 /// by using vector.transpose 28 class InnerOuterDimReductionConversion 29 : public OpRewritePattern<vector::MultiDimReductionOp> { 30 public: 31 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 32 33 explicit InnerOuterDimReductionConversion( 34 MLIRContext *context, vector::VectorMultiReductionLowering options) 35 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context), 36 useInnerDimsForReduction( 37 options == vector::VectorMultiReductionLowering::InnerReduction) {} 38 39 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 40 PatternRewriter &rewriter) const override { 41 auto src = multiReductionOp.getSource(); 42 auto loc = multiReductionOp.getLoc(); 43 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 44 45 // Separate reduction and parallel dims 46 auto reductionDimsRange = 47 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>(); 48 auto reductionDims = llvm::to_vector<4>(llvm::map_range( 49 reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); 50 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(), 51 reductionDims.end()); 52 int64_t reductionSize = reductionDims.size(); 53 SmallVector<int64_t, 4> parallelDims; 54 for (int64_t i = 0; i < srcRank; ++i) 55 if (!reductionDimsSet.contains(i)) 56 parallelDims.push_back(i); 57 58 // Add transpose only if inner-most/outer-most dimensions are not parallel 59 // and there are parallel dims. 60 if (parallelDims.empty()) 61 return failure(); 62 if (useInnerDimsForReduction && 63 (parallelDims == 64 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 65 return failure(); 66 67 if (!useInnerDimsForReduction && 68 (parallelDims != 69 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 70 return failure(); 71 72 SmallVector<int64_t, 4> indices; 73 if (useInnerDimsForReduction) { 74 indices.append(parallelDims.begin(), parallelDims.end()); 75 indices.append(reductionDims.begin(), reductionDims.end()); 76 } else { 77 indices.append(reductionDims.begin(), reductionDims.end()); 78 indices.append(parallelDims.begin(), parallelDims.end()); 79 } 80 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices); 81 SmallVector<bool> reductionMask(srcRank, false); 82 for (int i = 0; i < reductionSize; ++i) { 83 if (useInnerDimsForReduction) 84 reductionMask[srcRank - i - 1] = true; 85 else 86 reductionMask[i] = true; 87 } 88 rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>( 89 multiReductionOp, transposeOp.getResult(), reductionMask, 90 multiReductionOp.getKind()); 91 return success(); 92 } 93 94 private: 95 const bool useInnerDimsForReduction; 96 }; 97 98 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction 99 /// dimensions are either inner most or outer most. 100 class ReduceMultiDimReductionRank 101 : public OpRewritePattern<vector::MultiDimReductionOp> { 102 public: 103 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 104 105 explicit ReduceMultiDimReductionRank( 106 MLIRContext *context, vector::VectorMultiReductionLowering options) 107 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context), 108 useInnerDimsForReduction( 109 options == vector::VectorMultiReductionLowering::InnerReduction) {} 110 111 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 112 PatternRewriter &rewriter) const override { 113 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 114 auto srcShape = multiReductionOp.getSourceVectorType().getShape(); 115 auto loc = multiReductionOp.getLoc(); 116 117 // If rank less than 2, nothing to do. 118 if (srcRank < 2) 119 return failure(); 120 121 // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail. 122 SmallVector<bool> reductionMask = multiReductionOp.getReductionMask(); 123 if (srcRank == 2 && reductionMask.front() != reductionMask.back()) 124 return failure(); 125 126 // 1. Separate reduction and parallel dims. 127 SmallVector<int64_t, 4> parallelDims, parallelShapes; 128 SmallVector<int64_t, 4> reductionDims, reductionShapes; 129 for (const auto &it : llvm::enumerate(reductionMask)) { 130 int64_t i = it.index(); 131 bool isReduction = it.value(); 132 if (isReduction) { 133 reductionDims.push_back(i); 134 reductionShapes.push_back(srcShape[i]); 135 } else { 136 parallelDims.push_back(i); 137 parallelShapes.push_back(srcShape[i]); 138 } 139 } 140 141 // 2. Compute flattened parallel and reduction sizes. 142 int flattenedParallelDim = 0; 143 int flattenedReductionDim = 0; 144 if (!parallelShapes.empty()) { 145 flattenedParallelDim = 1; 146 for (auto d : parallelShapes) 147 flattenedParallelDim *= d; 148 } 149 if (!reductionShapes.empty()) { 150 flattenedReductionDim = 1; 151 for (auto d : reductionShapes) 152 flattenedReductionDim *= d; 153 } 154 // We must at least have some parallel or some reduction. 155 assert((flattenedParallelDim || flattenedReductionDim) && 156 "expected at least one parallel or reduction dim"); 157 158 // 3. Fail if reduction/parallel dims are not contiguous. 159 // Check parallelDims are exactly [0 .. size). 160 int64_t counter = 0; 161 if (useInnerDimsForReduction && 162 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) 163 return failure(); 164 // Check parallelDims are exactly {reductionDims.size()} + [0 .. size). 165 counter = reductionDims.size(); 166 if (!useInnerDimsForReduction && 167 llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; })) 168 return failure(); 169 170 // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into 171 // a single parallel (resp. reduction) dim. 172 SmallVector<bool, 2> mask; 173 SmallVector<int64_t, 2> vectorShape; 174 if (flattenedParallelDim) { 175 mask.push_back(false); 176 vectorShape.push_back(flattenedParallelDim); 177 } 178 if (flattenedReductionDim) { 179 mask.push_back(true); 180 vectorShape.push_back(flattenedReductionDim); 181 } 182 if (!useInnerDimsForReduction && vectorShape.size() == 2) { 183 std::swap(mask.front(), mask.back()); 184 std::swap(vectorShape.front(), vectorShape.back()); 185 } 186 auto castedType = VectorType::get( 187 vectorShape, multiReductionOp.getSourceVectorType().getElementType()); 188 Value cast = rewriter.create<vector::ShapeCastOp>( 189 loc, castedType, multiReductionOp.getSource()); 190 191 // 5. Creates the flattened form of vector.multi_reduction with inner/outer 192 // most dim as reduction. 193 auto newOp = rewriter.create<vector::MultiDimReductionOp>( 194 loc, cast, mask, multiReductionOp.getKind()); 195 196 // 6. If there are no parallel shapes, the result is a scalar. 197 // TODO: support 0-d vectors when available. 198 if (parallelShapes.empty()) { 199 rewriter.replaceOp(multiReductionOp, newOp.getDest()); 200 return success(); 201 } 202 203 // 7. Creates shape cast for the output n-D -> 2-D 204 VectorType outputCastedType = VectorType::get( 205 parallelShapes, 206 multiReductionOp.getSourceVectorType().getElementType()); 207 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( 208 multiReductionOp, outputCastedType, newOp.getDest()); 209 return success(); 210 } 211 212 private: 213 const bool useInnerDimsForReduction; 214 }; 215 216 /// Unrolls vector.multi_reduction with outermost reductions 217 /// and combines results 218 struct TwoDimMultiReductionToElementWise 219 : public OpRewritePattern<vector::MultiDimReductionOp> { 220 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 221 222 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 223 PatternRewriter &rewriter) const override { 224 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 225 // Rank-2 ["parallel", "reduce"] or bail. 226 if (srcRank != 2) 227 return failure(); 228 229 if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0)) 230 return failure(); 231 232 auto loc = multiReductionOp.getLoc(); 233 ArrayRef<int64_t> srcShape = 234 multiReductionOp.getSourceVectorType().getShape(); 235 236 Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); 237 if (!elementType.isIntOrIndexOrFloat()) 238 return failure(); 239 240 Value result = 241 rewriter.create<vector::ExtractOp>(loc, multiReductionOp.getSource(), 0) 242 .getResult(); 243 for (int64_t i = 1; i < srcShape[0]; i++) { 244 auto operand = rewriter.create<vector::ExtractOp>( 245 loc, multiReductionOp.getSource(), i); 246 result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), 247 operand, result); 248 } 249 250 rewriter.replaceOp(multiReductionOp, result); 251 return success(); 252 } 253 }; 254 255 /// Converts 2d vector.multi_reduction with inner most reduction dimension into 256 /// a sequence of vector.reduction ops. 257 struct TwoDimMultiReductionToReduction 258 : public OpRewritePattern<vector::MultiDimReductionOp> { 259 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 260 261 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 262 PatternRewriter &rewriter) const override { 263 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 264 if (srcRank != 2) 265 return failure(); 266 267 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) 268 return failure(); 269 270 auto loc = multiReductionOp.getLoc(); 271 Value result = rewriter.create<arith::ConstantOp>( 272 loc, multiReductionOp.getDestType(), 273 rewriter.getZeroAttr(multiReductionOp.getDestType())); 274 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; 275 276 for (int i = 0; i < outerDim; ++i) { 277 auto v = rewriter.create<vector::ExtractOp>( 278 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i}); 279 auto reducedValue = rewriter.create<vector::ReductionOp>( 280 loc, multiReductionOp.getKind(), v); 281 result = rewriter.create<vector::InsertElementOp>( 282 loc, reducedValue, result, 283 rewriter.create<arith::ConstantIndexOp>(loc, i)); 284 } 285 rewriter.replaceOp(multiReductionOp, result); 286 return success(); 287 } 288 }; 289 290 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d 291 /// form with both a single parallel and reduction dimension. 292 /// This is achieved with a simple vector.shape_cast that inserts a leading 1. 293 /// The case with a single parallel dimension is a noop and folds away 294 /// separately. 295 struct OneDimMultiReductionToTwoDim 296 : public OpRewritePattern<vector::MultiDimReductionOp> { 297 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 298 299 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 300 PatternRewriter &rewriter) const override { 301 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 302 // Rank-1 or bail. 303 if (srcRank != 1) 304 return failure(); 305 306 auto loc = multiReductionOp.getLoc(); 307 auto srcVectorType = multiReductionOp.getSourceVectorType(); 308 auto srcShape = srcVectorType.getShape(); 309 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()}, 310 srcVectorType.getElementType()); 311 assert(!multiReductionOp.getDestType().isa<VectorType>() && 312 "multi_reduction with a single dimension expects a scalar result"); 313 314 // If the unique dim is reduced and we insert a parallel in front, we need a 315 // {false, true} mask. 316 SmallVector<bool, 2> mask{false, true}; 317 318 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) 319 Value cast = rewriter.create<vector::ShapeCastOp>( 320 loc, castedType, multiReductionOp.getSource()); 321 Value reduced = rewriter.create<vector::MultiDimReductionOp>( 322 loc, cast, mask, multiReductionOp.getKind()); 323 rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced, 324 ArrayRef<int64_t>{0}); 325 return success(); 326 } 327 }; 328 329 void mlir::vector::populateVectorMultiReductionLoweringPatterns( 330 RewritePatternSet &patterns, VectorMultiReductionLowering options) { 331 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( 332 patterns.getContext(), options); 333 patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext()); 334 if (options == VectorMultiReductionLowering ::InnerReduction) 335 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext()); 336 else 337 patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext()); 338 } 339