//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose. Whether the reduced dimensions are moved to the
/// inner-most or the outer-most positions is selected by the
/// VectorMultiReductionLowering option passed to the constructor.
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto src = multiReductionOp.source();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims. The `reduction_dims` attribute
    // holds the indices of the reduced dimensions; every other source
    // dimension is parallel.
    auto reductionDimsRange =
        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    // Inner-reduction form: if the parallel dims are already exactly
    // [0 .. #parallel), the op is already in the desired form — bail.
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    // Outer-reduction form: only handle the case where the parallel dims are
    // the leading dims (so the transpose below moves them to the back);
    // anything else is left for other patterns.
    if (!useInnerDimsForReduction &&
        (parallelDims !=
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    // Build the transpose permutation placing the reduced dims inner-most
    // (parallel first) or outer-most (reduction first).
    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    // After the transpose, the reduced dims occupy the last (resp. first)
    // `reductionSize` positions of the new source.
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
        multiReductionOp, transposeOp.result(), reductionMask,
        multiReductionOp.kind());
    return success();
  }

private:
  // True when lowering to the inner-most-reduction form (InnerReduction).
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either inner most or outer most.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims, keeping both the dim indices
    // and their extents.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes (products of the
    // respective extents). A value of 0 means "no dims of that kind".
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous. A single
    // shape_cast can only collapse adjacent dims.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction dim) into
    // a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    // For the outer-reduction form the reduced dim must come first; the
    // vectors above were built parallel-first, so swap when both are present.
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());

    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
    // most dim as reduction.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.dest());
      return success();
    }

    // 7. Creates shape cast for the output n-D -> 2-D: expand the flattened
    // parallel dim back to the original parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.dest());
    return success();
  }

private:
  // True when lowering to the inner-most-reduction form (InnerReduction).
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results elementwise, one row at a time.
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    // Require dim 0 reduced and dim 1 parallel.
    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    // Elementwise arith ops below only exist for int/index/float elements.
    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    // Seed the accumulator with row 0, then fold each remaining row in with
    // the elementwise op matching the combining kind.
    Value result =
        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
            .getResult();
    for (int64_t i = 1; i < srcShape[0]; i++) {
      auto operand =
          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
      switch (multiReductionOp.kind()) {
      case vector::CombiningKind::ADD:
        if (elementType.isIntOrIndex())
          result = rewriter.create<arith::AddIOp>(loc, operand, result);
        else
          result = rewriter.create<arith::AddFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MUL:
        if (elementType.isIntOrIndex())
          result = rewriter.create<arith::MulIOp>(loc, operand, result);
        else
          result = rewriter.create<arith::MulFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINUI:
        result = rewriter.create<arith::MinUIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINSI:
        result = rewriter.create<arith::MinSIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINF:
        result = rewriter.create<arith::MinFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXUI:
        result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXSI:
        result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXF:
        result = rewriter.create<arith::MaxFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::AND:
        result = rewriter.create<arith::AndIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::OR:
        result = rewriter.create<arith::OrIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::XOR:
        result = rewriter.create<arith::XOrIOp>(loc, operand, result);
        break;
      }
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with inner most reduction dimension into
/// a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    // Require dim 0 parallel and dim 1 reduced (["parallel", "reduce"]).
    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    auto loc = multiReductionOp.getLoc();
    // Zero-initialized destination vector; each lane is overwritten below.
    Value result = rewriter.create<ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    // TODO: Add vector::CombiningKind attribute instead of string to
    // vector.reduction.
    auto getKindStr = [](vector::CombiningKind kind) {
      switch (kind) {
      case vector::CombiningKind::ADD:
        return "add";
      case vector::CombiningKind::MUL:
        return "mul";
      case vector::CombiningKind::MINUI:
        return "minui";
      case vector::CombiningKind::MINSI:
        return "minsi";
      case vector::CombiningKind::MINF:
        return "minf";
      case vector::CombiningKind::MAXUI:
        return "maxui";
      case vector::CombiningKind::MAXSI:
        return "maxsi";
      case vector::CombiningKind::MAXF:
        return "maxf";
      case vector::CombiningKind::AND:
        return "and";
      case vector::CombiningKind::OR:
        return "or";
      case vector::CombiningKind::XOR:
        return "xor";
      }
      llvm_unreachable("unknown combining kind");
    };

    // Reduce each outer row to a scalar and insert it at position i of the
    // result vector.
    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
      auto reducedValue = rewriter.create<vector::ReductionOp>(
          loc, getElementTypeOrSelf(multiReductionOp.getDestType()),
          rewriter.getStringAttr(getKindStr(multiReductionOp.kind())), v,
          ValueRange{});
      result = rewriter.create<vector::InsertElementOp>(
          loc, reducedValue, result,
          rewriter.create<arith::ConstantIndexOp>(loc, i));
    }
    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
366 struct OneDimMultiReductionToTwoDim 367 : public OpRewritePattern<vector::MultiDimReductionOp> { 368 using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern; 369 370 LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, 371 PatternRewriter &rewriter) const override { 372 auto srcRank = multiReductionOp.getSourceVectorType().getRank(); 373 // Rank-1 or bail. 374 if (srcRank != 1) 375 return failure(); 376 377 auto loc = multiReductionOp.getLoc(); 378 auto srcVectorType = multiReductionOp.getSourceVectorType(); 379 auto srcShape = srcVectorType.getShape(); 380 auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()}, 381 srcVectorType.getElementType()); 382 assert(!multiReductionOp.getDestType().isa<VectorType>() && 383 "multi_reduction with a single dimension expects a scalar result"); 384 385 // If the unique dim is reduced and we insert a parallel in front, we need a 386 // {false, true} mask. 387 SmallVector<bool, 2> mask{false, true}; 388 389 /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) 390 Value cast = rewriter.create<vector::ShapeCastOp>( 391 loc, castedType, multiReductionOp.source()); 392 Value reduced = rewriter.create<vector::MultiDimReductionOp>( 393 loc, cast, mask, multiReductionOp.kind()); 394 rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced, 395 ArrayRef<int64_t>{0}); 396 return success(); 397 } 398 }; 399 400 void mlir::vector::populateVectorMultiReductionLoweringPatterns( 401 RewritePatternSet &patterns, VectorMultiReductionLowering options) { 402 patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>( 403 patterns.getContext(), options); 404 patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext()); 405 if (options == VectorMultiReductionLowering ::InnerReduction) 406 patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext()); 407 else 408 
patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext()); 409 } 410