//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

/// This file implements the following transformations as composable atomic
/// patterns.
26 27 /// Converts vector.multi_reduction into inner-most/outer-most reduction form 28 /// by using vector.transpose 29 class InnerOuterDimReductionConversion 30 : public OpRewritePattern<vector::MultiDimReductionOp> { 31 public: 32 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 33 34 explicit InnerOuterDimReductionConversion( 35 MLIRContext *context, vector::VectorMultiReductionLowering options) 36 : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context), 37 useInnerDimsForReduction( 38 options == vector::VectorMultiReductionLowering::InnerReduction) {} 39 40 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 41 PatternRewriter &rewriter) const override { 42 auto src = multiReductionOp.getSource(); 43 auto loc = multiReductionOp.getLoc(); 44 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 45 46 // Separate reduction and parallel dims 47 auto reductionDimsRange = 48 multiReductionOp.getReductionDims().getAsValueRange<IntegerAttr>(); 49 auto reductionDims = llvm::to_vector<4>(llvm::map_range( 50 reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); 51 llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(), 52 reductionDims.end()); 53 int64_t reductionSize = reductionDims.size(); 54 SmallVector<int64_t, 4> parallelDims; 55 for (int64_t i = 0; i < srcRank; ++i) 56 if (!reductionDimsSet.contains(i)) 57 parallelDims.push_back(i); 58 59 // Add transpose only if inner-most/outer-most dimensions are not parallel 60 // and there are parallel dims. 
61 if (parallelDims.empty()) 62 return failure(); 63 if (useInnerDimsForReduction && 64 (parallelDims == 65 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 66 return failure(); 67 68 if (!useInnerDimsForReduction && 69 (parallelDims != 70 llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size())))) 71 return failure(); 72 73 SmallVector<int64_t, 4> indices; 74 if (useInnerDimsForReduction) { 75 indices.append(parallelDims.begin(), parallelDims.end()); 76 indices.append(reductionDims.begin(), reductionDims.end()); 77 } else { 78 indices.append(reductionDims.begin(), reductionDims.end()); 79 indices.append(parallelDims.begin(), parallelDims.end()); 80 } 81 auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices); 82 SmallVector<bool> reductionMask(srcRank, false); 83 for (int i = 0; i < reductionSize; ++i) { 84 if (useInnerDimsForReduction) 85 reductionMask[srcRank - i - 1] = true; 86 else 87 reductionMask[i] = true; 88 } 89 rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>( 90 multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(), 91 reductionMask, multiReductionOp.getKind()); 92 return success(); 93 } 94 95 private: 96 const bool useInnerDimsForReduction; 97 }; 98 99 /// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction 100 /// dimensions are either inner most or outer most. 
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  /// Collapses all parallel dims into one dim and all reduction dims into
  /// another via vector.shape_cast, reduces, then expands the result back.
  /// Requires the parallel dims to be contiguous and in the position implied
  /// by the lowering mode (leading for inner reductions, trailing otherwise).
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    // (front != back means exactly one of the two dims is reduced.)
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims, recording both the dim index
    // and its static size.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes (product of the
    // respective dim sizes). A value of 0 means "no such dims".
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous. The mutable
    // `counter` captured by the lambda checks that parallelDims is exactly
    // the consecutive run starting at `counter`.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim)
    // into a single parallel (resp. reduction) dim. The rank-2 (or rank-1)
    // shape is built parallel-first, then swapped for outer reductions.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    // The accumulator has the parallel shape only; flatten it to match.
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType());
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.getDest());
      return success();
    }

    // 7. Shape cast to expand the flattened (1-D) result back to the
    // original n-D parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.getDest());
    return success();
  }

private:
  // True when reductions should be lowered to inner-most dimensions.
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  /// Lowers a rank-2 ["reduce", "parallel"] multi_reduction to an unrolled
  /// chain of element-wise arith combines over the extracted rows.
  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    // Dim 0 must be reduced and dim 1 parallel.
    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    // makeArithReduction only supports int/index/float elements.
    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    // Fold each extracted row into the accumulator with the element-wise
    // form of the reduction kind.
    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with inner most reduction dimension into
/// a sequence of vector.reduction ops.
262 struct TwoDimMultiReductionToReduction 263 : public OpRewritePattern<vector::MultiDimReductionOp> { 264 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 265 266 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 267 PatternRewriter &rewriter) const override { 268 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 269 if (srcRank != 2) 270 return failure(); 271 272 if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) 273 return failure(); 274 275 auto loc = multiReductionOp.getLoc(); 276 Value result = rewriter.create<arith::ConstantOp>( 277 loc, multiReductionOp.getDestType(), 278 rewriter.getZeroAttr(multiReductionOp.getDestType())); 279 int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; 280 281 for (int i = 0; i < outerDim; ++i) { 282 auto v = rewriter.create<vector::ExtractOp>( 283 loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i}); 284 auto acc = rewriter.create<vector::ExtractOp>( 285 loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i}); 286 auto reducedValue = rewriter.create<vector::ReductionOp>( 287 loc, multiReductionOp.getKind(), v, acc); 288 result = rewriter.create<vector::InsertElementOp>( 289 loc, reducedValue, result, 290 rewriter.create<arith::ConstantIndexOp>(loc, i)); 291 } 292 rewriter.replaceOp(multiReductionOp, result); 293 return success(); 294 } 295 }; 296 297 /// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d 298 /// form with both a single parallel and reduction dimension. 299 /// This is achieved with a simple vector.shape_cast that inserts a leading 1. 300 /// The case with a single parallel dimension is a noop and folds away 301 /// separately. 
302 struct OneDimMultiReductionToTwoDim 303 : public OpRewritePattern<vector::MultiDimReductionOp> { 304 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 305 306 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 307 PatternRewriter &rewriter) const override { 308 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 309 // Rank-1 or bail. 310 if (srcRank != 1) 311 return failure(); 312 313 auto loc = multiReductionOp.getLoc(); 314 auto srcVectorType = multiReductionOp.getSourceVectorType(); 315 auto srcShape = srcVectorType.getShape(); 316 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()}, 317 srcVectorType.getElementType()); 318 auto accType = 319 VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType()); 320 assert(!multiReductionOp.getDestType().isa<VectorType>() && 321 "multi_reduction with a single dimension expects a scalar result"); 322 323 // If the unique dim is reduced and we insert a parallel in front, we need a 324 // {false, true} mask. 
325 SmallVector<bool, 2> mask{false, true}; 326 327 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) 328 Value cast = rewriter.create<vector::ShapeCastOp>( 329 loc, castedType, multiReductionOp.getSource()); 330 Value castAcc = rewriter.create<vector::BroadcastOp>( 331 loc, accType, multiReductionOp.getAcc()); 332 Value reduced = rewriter.create<vector::MultiDimReductionOp>( 333 loc, cast, castAcc, mask, multiReductionOp.getKind()); 334 rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced, 335 ArrayRef<int64_t>{0}); 336 return success(); 337 } 338 }; 339 340 void mlir::vector::populateVectorMultiReductionLoweringPatterns( 341 RewritePatternSet &patterns, VectorMultiReductionLowering options) { 342 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( 343 patterns.getContext(), options); 344 patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext()); 345 if (options == VectorMultiReductionLowering ::InnerReduction) 346 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext()); 347 else 348 patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext()); 349 } 350