//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
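  // The dimension to drop sits deeper than the leading one: peel off the
  // leading dimension, recurse into each sub-vector with 'index - 1', and
  // reassemble the results into a vector whose shape has the target
  // dimension removed.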
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.vector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
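///
/// For instance, a rough sketch of the unrolling for
///   %x = vector.constant_mask [2, 3] : vector<4x4xi1>
/// is:
///   %z = arith.constant dense<false> : vector<4x4xi1>
///   %l = vector.constant_mask [3] : vector<4xi1>
///   %0 = vector.insert %l, %z [0] : vector<4xi1> into vector<4x4xi1>
///   %x = vector.insert %l, %0 [1] : vector<4xi1> into vector<4x4xi1>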
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a    |
///   %1 = select %0, %l, %zeroes       |
///   %r = vector.insert %1, %pr [i]    | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
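///
/// For instance, a rough sketch of how %x = vector.create_mask %a, %b
///   : vector<4x3xi1> unrolls along the leading dimension:
///   %l  = vector.create_mask %b : vector<3xi1>
///   %f  = arith.constant dense<false> : vector<3xi1>
///   %z  = arith.constant dense<false> : vector<4x3xi1>
///   %c0 = arith.constant 0 : index
///   %p0 = arith.cmpi slt, %c0, %a : index
///   %s0 = arith.select %p0, %l, %f : vector<3xi1>
///   %r0 = vector.insert %s0, %z [0] : vector<3xi1> into vector<4x3xi1>
///   ..repeated for rows 1..3, yielding %x.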
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This moves the broadcast and
/// contraction ops closer together, so that the CombineContractBroadcast
/// pattern can kick in when casting ops sit between them.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  bcastOp.source(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders cast(transpose) to transpose(cast). This moves the transpose and
/// contraction ops closer together, so that the CombineContractTranspose
/// pattern can kick in when casting ops sit between them.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
///   %1 = vector.transpose %0, [2, 0, 1]
///     : vector<32x16x8xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnTranspose
    : public OpInterfaceRewritePattern<CastOpInterface> {

  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
    if (!transpOp)
      return failure();

    auto castResTy = transpOp.getVectorType();
    castResTy = VectorType::get(castResTy.getShape(),
                                getElementTypeOrSelf(op->getResult(0)));
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  transpOp.vector(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResult(0).getType(), castOp->getResult(0),
        transpOp.getTransp());
    return success();
  }
};

} // namespace

/// Creates an AddIOp if `isInt` is true otherwise creates an arith::AddFOp
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true otherwise creates a MulFOp using
/// operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
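  // Both operands now have (m, k) x (k, n) row-major layout. Flatten them to
  // 1-D vectors, which is the form vector.matmul (and ultimately
  // llvm.matrix.multiply) expects; the original 2-D shapes are carried by the
  // rows/columns attributes below.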
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
1517 // 1518 // TODO: break down into transpose/reshape/cast ops 1519 // when they become available to avoid code dup 1520 // TODO: investigate lowering order impact on performance 1521 LogicalResult 1522 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1523 PatternRewriter &rewriter) const { 1524 // TODO: implement masks. 1525 if (llvm::size(op.masks()) != 0) 1526 return failure(); 1527 1528 if (failed(filter(op))) 1529 return failure(); 1530 1531 // TODO: support mixed mode contract lowering. 1532 if (op.getLhsType().getElementType() != 1533 getElementTypeOrSelf(op.getAccType()) || 1534 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1535 return failure(); 1536 1537 // TODO: implement benefits, cost models. 1538 MLIRContext *ctx = op.getContext(); 1539 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1540 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1541 return success(); 1542 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1543 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1544 return success(); 1545 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1546 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1547 return success(); 1548 1549 // Find first batch dimension in LHS/RHS, and lower when found. 1550 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1551 if (!batchDimMap.empty()) { 1552 int64_t lhsIndex = batchDimMap[0].first; 1553 int64_t rhsIndex = batchDimMap[0].second; 1554 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1555 return success(); 1556 } 1557 1558 // Collect contracting dimensions. 1559 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1560 op.getContractingDimMap(); 1561 DenseSet<int64_t> lhsContractingDimSet; 1562 DenseSet<int64_t> rhsContractingDimSet; 1563 for (auto &dimPair : contractingDimMap) { 1564 lhsContractingDimSet.insert(dimPair.first); 1565 rhsContractingDimSet.insert(dimPair.second); 1566 } 1567 1568 // Find first free dimension in LHS, and lower when found. 1569 VectorType lhsType = op.getLhsType(); 1570 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1571 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1572 rewriter.replaceOp( 1573 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1574 return success(); 1575 } 1576 } 1577 1578 // Find first free dimension in RHS, and lower when found. 1579 VectorType rhsType = op.getRhsType(); 1580 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1581 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1582 rewriter.replaceOp( 1583 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1584 return success(); 1585 } 1586 } 1587 1588 // Lower the first remaining reduction dimension. 1589 if (!contractingDimMap.empty()) { 1590 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1591 return success(); 1592 } 1593 1594 return failure(); 1595 } 1596 1597 // Lower one parallel dimension. 1598 // TODO: consider reusing existing contract unrolling 1599 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1600 int64_t lhsIndex, int64_t rhsIndex, 1601 PatternRewriter &rewriter) const { 1602 VectorType lhsType = op.getLhsType(); 1603 VectorType rhsType = op.getRhsType(); 1604 VectorType resType = op.getResultType().cast<VectorType>(); 1605 // Find the iterator type index and result index. 
1606 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1607 int64_t iterIndex = -1; 1608 int64_t dimSize = -1; 1609 if (lhsIndex >= 0) { 1610 iterIndex = iMap[0].getDimPosition(lhsIndex); 1611 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1612 "parallel index should be free in LHS or batch in LHS/RHS"); 1613 dimSize = lhsType.getDimSize(lhsIndex); 1614 } else { 1615 assert(rhsIndex >= 0 && "missing parallel index"); 1616 iterIndex = iMap[1].getDimPosition(rhsIndex); 1617 dimSize = rhsType.getDimSize(rhsIndex); 1618 } 1619 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1620 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1621 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1622 int64_t resIndex = lookup.getValue(); 1623 // Construct new iterator types and affine map array attribute. 1624 std::array<AffineMap, 3> lowIndexingMaps = { 1625 adjustMap(iMap[0], iterIndex, rewriter), 1626 adjustMap(iMap[1], iterIndex, rewriter), 1627 adjustMap(iMap[2], iterIndex, rewriter)}; 1628 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1629 auto lowIter = 1630 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1631 // Unroll into a series of lower dimensional vector.contract ops. 1632 Location loc = op.getLoc(); 1633 Value result = rewriter.create<arith::ConstantOp>( 1634 loc, resType, rewriter.getZeroAttr(resType)); 1635 for (int64_t d = 0; d < dimSize; ++d) { 1636 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1637 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1638 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1639 Value lowContract = rewriter.create<vector::ContractionOp>( 1640 loc, lhs, rhs, acc, lowAffine, lowIter); 1641 result = 1642 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1643 } 1644 return result; 1645 } 1646 1647 // Lower one reduction dimension. 1648 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1649 PatternRewriter &rewriter) const { 1650 auto loc = op.getLoc(); 1651 VectorType lhsType = op.getLhsType(); 1652 VectorType rhsType = op.getRhsType(); 1653 Type resType = op.getResultType(); 1654 assert(!resType.isa<VectorType>()); 1655 bool isInt = resType.isa<IntegerType>(); 1656 // Use iterator index 0. 1657 int64_t iterIndex = 0; 1658 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1659 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1660 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1661 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1662 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1663 int64_t lhsIndex = lookupLhs.getValue(); 1664 int64_t rhsIndex = lookupRhs.getValue(); 1665 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1666 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1667 // Base case. 1668 if (lhsType.getRank() == 1) { 1669 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1670 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1671 auto kind = vector::CombiningKind::ADD; 1672 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1673 if (auto acc = op.acc()) 1674 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1675 return res; 1676 } 1677 // Construct new iterator types and affine map array attribute. 
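// E.g. (illustrative): dropping iterator index 0 rewrites an indexing map
// (d0, d1, d2) -> (d0, d2) into (d0, d1) -> (d1), with the remaining
// dimensions renumbered accordingly.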
1678 std::array<AffineMap, 3> lowIndexingMaps = { 1679 adjustMap(iMap[0], iterIndex, rewriter), 1680 adjustMap(iMap[1], iterIndex, rewriter), 1681 adjustMap(iMap[2], iterIndex, rewriter)}; 1682 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1683 auto lowIter = 1684 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1685 // Unroll into a series of lower dimensional vector.contract ops. 1686 // By feeding the initial accumulator into the first contraction, 1687 // and the result of each contraction into the next, eventually 1688 // the sum of all reductions is computed. 1689 Value result = op.acc(); 1690 for (int64_t d = 0; d < dimSize; ++d) { 1691 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1692 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1693 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1694 lowAffine, lowIter); 1695 } 1696 return result; 1697 } 1698 1699 } // namespace mlir 1700 1701 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1702 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1703 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1704 OpBuilder::InsertionGuard guard(builder); 1705 builder.setInsertionPointAfter(op); 1706 Location loc = op->getLoc(); 1707 if (op->getNumResults() != 1) 1708 return {}; 1709 Value result = op->getResult(0); 1710 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1711 if (!type || map.getNumResults() != multiplicity.size()) 1712 return {}; 1713 // For each dimension being distributed check that the size is a multiple of 1714 // the multiplicity. To handle more sizes we would need to support masking. 1715 unsigned multiplictyCount = 0; 1716 for (auto exp : map.getResults()) { 1717 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1718 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1719 type.getDimSize(affinExp.getPosition()) % 1720 multiplicity[multiplictyCount++] != 1721 0) 1722 return {}; 1723 } 1724 DistributeOps ops; 1725 ops.extract = 1726 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1727 ops.insert = 1728 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1729 return ops; 1730 } 1731 1732 /// Progressive lowering of transfer_read. This pattern supports lowering of 1733 /// `vector.transfer_read` to a combination of `vector.load` and 1734 /// `vector.broadcast` if all of the following hold: 1735 /// - Stride of most minor memref dimension must be 1. 1736 /// - Out-of-bounds masking is not required. 1737 /// - If the memref's element type is a vector type then it coincides with the 1738 /// result type. 1739 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1740 struct TransferReadToVectorLoadLowering 1741 : public OpRewritePattern<vector::TransferReadOp> { 1742 TransferReadToVectorLoadLowering(MLIRContext *context, 1743 llvm::Optional<unsigned> maxRank) 1744 : OpRewritePattern<vector::TransferReadOp>(context), 1745 maxTransferRank(maxRank) {} 1746 1747 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1748 PatternRewriter &rewriter) const override { 1749 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1750 return failure(); 1751 1752 SmallVector<unsigned, 4> broadcastedDims; 1753 // Permutations are handled by VectorToSCF or 1754 // populateVectorTransferPermutationMapLoweringPatterns. 1755 // We let the 0-d corner case pass-through as it is supported. 
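// E.g. (illustrative): (d0, d1) -> (0, d1) is a minor identity with the
// leading result dimension broadcast, so `broadcastedDims` becomes {0};
// a genuine permutation such as (d0, d1) -> (d1, d0) is rejected here.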
1756 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1757 &broadcastedDims)) 1758 return failure(); 1759 1760 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1761 if (!memRefType) 1762 return failure(); 1763 1764 // Non-unit strides are handled by VectorToSCF. 1765 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1766 return failure(); 1767 1768 // If there is broadcasting involved then we first load the unbroadcasted 1769 // vector, and then broadcast it with `vector.broadcast`. 1770 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1771 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1772 vectorShape.end()); 1773 for (unsigned i : broadcastedDims) 1774 unbroadcastedVectorShape[i] = 1; 1775 VectorType unbroadcastedVectorType = VectorType::get( 1776 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1777 1778 // `vector.load` supports vector types as memref's elements only when the 1779 // resulting vector type is the same as the element type. 1780 auto memrefElTy = memRefType.getElementType(); 1781 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1782 return failure(); 1783 1784 // Otherwise, element types of the memref and the vector must match. 1785 if (!memrefElTy.isa<VectorType>() && 1786 memrefElTy != read.getVectorType().getElementType()) 1787 return failure(); 1788 1789 // Out-of-bounds dims are handled by MaterializeTransferMask. 1790 if (read.hasOutOfBoundsDim()) 1791 return failure(); 1792 1793 // Create vector load op. 1794 Operation *loadOp; 1795 if (read.mask()) { 1796 Value fill = rewriter.create<vector::SplatOp>( 1797 read.getLoc(), unbroadcastedVectorType, read.padding()); 1798 loadOp = rewriter.create<vector::MaskedLoadOp>( 1799 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1800 read.mask(), fill); 1801 } else { 1802 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1803 unbroadcastedVectorType, 1804 read.source(), read.indices()); 1805 } 1806 1807 // Insert a broadcasting op if required. 1808 if (!broadcastedDims.empty()) { 1809 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1810 read, read.getVectorType(), loadOp->getResult(0)); 1811 } else { 1812 rewriter.replaceOp(read, loadOp->getResult(0)); 1813 } 1814 1815 return success(); 1816 } 1817 1818 llvm::Optional<unsigned> maxTransferRank; 1819 }; 1820 1821 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1822 // TODO: we shouldn't cross the vector/scalar domains just for this 1823 // but atm we lack the infra to avoid it. Possible solutions include: 1824 // - go directly to LLVM + bitcast 1825 // - introduce a bitcast op and likely a new pointer dialect 1826 // - let memref.load/store additionally support the 0-d vector case 1827 // There are still deeper data layout issues lingering even in this 1828 // trivial case (for architectures for which this matters). 
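//
// E.g. (illustrative, single-element case):
//   %0 = vector.load %mem[%i] : memref<100xf32>, vector<1xf32>
// becomes:
//   %s = memref.load %mem[%i] : memref<100xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>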
1829 struct VectorLoadToMemrefLoadLowering 1830 : public OpRewritePattern<vector::LoadOp> { 1831 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 1832 1833 LogicalResult matchAndRewrite(vector::LoadOp loadOp, 1834 PatternRewriter &rewriter) const override { 1835 auto vecType = loadOp.getVectorType(); 1836 if (vecType.getNumElements() != 1) 1837 return failure(); 1838 auto memrefLoad = rewriter.create<memref::LoadOp>( 1839 loadOp.getLoc(), loadOp.base(), loadOp.indices()); 1840 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, 1841 memrefLoad); 1842 return success(); 1843 } 1844 }; 1845 1846 /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 1847 struct VectorStoreToMemrefStoreLowering 1848 : public OpRewritePattern<vector::StoreOp> { 1849 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 1850 1851 LogicalResult matchAndRewrite(vector::StoreOp storeOp, 1852 PatternRewriter &rewriter) const override { 1853 auto vecType = storeOp.getVectorType(); 1854 if (vecType.getNumElements() != 1) 1855 return failure(); 1856 Value extracted; 1857 if (vecType.getRank() == 0) { 1858 // TODO: Unify once ExtractOp supports 0-d vectors. 1859 extracted = rewriter.create<vector::ExtractElementOp>( 1860 storeOp.getLoc(), storeOp.valueToStore()); 1861 } else { 1862 SmallVector<int64_t> indices(vecType.getRank(), 0); 1863 extracted = rewriter.create<vector::ExtractOp>( 1864 storeOp.getLoc(), storeOp.valueToStore(), indices); 1865 } 1866 1867 rewriter.replaceOpWithNewOp<memref::StoreOp>( 1868 storeOp, extracted, storeOp.base(), storeOp.indices()); 1869 return success(); 1870 } 1871 }; 1872 1873 /// Progressive lowering of transfer_write. This pattern supports lowering of 1874 /// `vector.transfer_write` to `vector.store` if all of the following hold: 1875 /// - Stride of most minor memref dimension must be 1. 1876 /// - Out-of-bounds masking is not required. 1877 /// - If the memref's element type is a vector type then it coincides with the 1878 /// type of the written value. 1879 /// - The permutation map is the minor identity map (neither permutation nor 1880 /// broadcasting is allowed). 1881 struct TransferWriteToVectorStoreLowering 1882 : public OpRewritePattern<vector::TransferWriteOp> { 1883 TransferWriteToVectorStoreLowering(MLIRContext *context, 1884 llvm::Optional<unsigned> maxRank) 1885 : OpRewritePattern<vector::TransferWriteOp>(context), 1886 maxTransferRank(maxRank) {} 1887 1888 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 1889 PatternRewriter &rewriter) const override { 1890 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) 1891 return failure(); 1892 1893 // Permutations are handled by VectorToSCF or 1894 // populateVectorTransferPermutationMapLoweringPatterns. 1895 if ( // pass-through for the 0-d corner case. 1896 !write.permutation_map().isMinorIdentity()) 1897 return failure(); 1898 1899 auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); 1900 if (!memRefType) 1901 return failure(); 1902 1903 // Non-unit strides are handled by VectorToSCF. 1904 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1905 return failure(); 1906 1907 // `vector.store` supports vector types as memref's elements only when the 1908 // type of the vector value being written is the same as the element type.
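// E.g. (illustrative): writing a vector<8xf16> value into a
// memref<16xvector<8xf16>> is supported, whereas writing a vector<4xf16>
// value into that memref is rejected below.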
1909 auto memrefElTy = memRefType.getElementType(); 1910 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) 1911 return failure(); 1912 1913 // Otherwise, element types of the memref and the vector must match. 1914 if (!memrefElTy.isa<VectorType>() && 1915 memrefElTy != write.getVectorType().getElementType()) 1916 return failure(); 1917 1918 // Out-of-bounds dims are handled by MaterializeTransferMask. 1919 if (write.hasOutOfBoundsDim()) 1920 return failure(); 1921 if (write.mask()) { 1922 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( 1923 write, write.source(), write.indices(), write.mask(), write.vector()); 1924 } else { 1925 rewriter.replaceOpWithNewOp<vector::StoreOp>( 1926 write, write.vector(), write.source(), write.indices()); 1927 } 1928 return success(); 1929 } 1930 1931 llvm::Optional<unsigned> maxTransferRank; 1932 }; 1933 1934 // Returns the values in `arrayAttr` as an integer vector. 1935 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) { 1936 return llvm::to_vector<4>( 1937 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 1938 [](IntegerAttr attr) { return attr.getInt(); })); 1939 } 1940 1941 // Shuffles vector.bitcast op after vector.extract op. 1942 // 1943 // This transforms IR like: 1944 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> 1945 // %1 = vector.extract %0[3] : vector<8xf16> 1946 // Into: 1947 // %0 = vector.extract %src[1] : vector<4xf32> 1948 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> 1949 // %2 = vector.extract %1[1] : vector<2xf16> 1950 struct BubbleDownVectorBitCastForExtract 1951 : public OpRewritePattern<vector::ExtractOp> { 1952 using OpRewritePattern::OpRewritePattern; 1953 1954 LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 1955 PatternRewriter &rewriter) const override { 1956 // Only support extracting scalars for now. 1957 if (extractOp.getVectorType().getRank() != 1) 1958 return failure(); 1959 1960 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 1961 if (!castOp) 1962 return failure(); 1963 1964 VectorType castSrcType = castOp.getSourceVectorType(); 1965 VectorType castDstType = castOp.getResultVectorType(); 1966 assert(castSrcType.getRank() == castDstType.getRank()); 1967 1968 // Fail to match if we only have one element in the cast op source. 1969 // This is to avoid an infinite loop given that this pattern can generate 1970 // such cases. 1971 if (castSrcType.getNumElements() == 1) 1972 return failure(); 1973 1974 // Only support casting to a larger number of elements for now. 1975 // E.g., vector<4xf32> -> vector<8xf16>. 1976 if (castSrcType.getNumElements() > castDstType.getNumElements()) 1977 return failure(); 1978 1979 unsigned expandRatio = 1980 castDstType.getNumElements() / castSrcType.getNumElements(); 1981 1982 auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { 1983 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 1984 }; 1985 1986 uint64_t index = getFirstIntValue(extractOp.position()); 1987 1988 // Get the single scalar (as a vector) in the source value that packs the 1989 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> 1990 VectorType oneScalarType = 1991 VectorType::get({1}, castSrcType.getElementType()); 1992 Value packedValue = rewriter.create<vector::ExtractOp>( 1993 extractOp.getLoc(), oneScalarType, castOp.source(), 1994 rewriter.getI64ArrayAttr(index / expandRatio)); 1995 1996 // Cast it to a vector with the desired scalar's type. 1997 // E.g.
f32 -> vector<2xf16> 1998 VectorType packedType = 1999 VectorType::get({expandRatio}, castDstType.getElementType()); 2000 Value castedValue = rewriter.create<vector::BitCastOp>( 2001 extractOp.getLoc(), packedType, packedValue); 2002 2003 // Finally extract the desired scalar. 2004 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 2005 extractOp, extractOp.getType(), castedValue, 2006 rewriter.getI64ArrayAttr(index % expandRatio)); 2007 2008 return success(); 2009 } 2010 }; 2011 2012 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 2013 // 2014 // This transforms IR like: 2015 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 2016 // %0 = vector.extract_strided_slice %cast { 2017 // offsets = [4], sizes = [4], strides = [1] 2018 // } : vector<8xf16> to vector<4xf16> 2019 // Into: 2020 // %0 = vector.extract_strided_slice %arg0 { 2021 // offsets = [2], sizes = [2], strides = [1] 2022 // } : vector<4xf32> to vector<2xf32> 2023 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 2024 struct BubbleDownBitCastForStridedSliceExtract 2025 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 2026 using OpRewritePattern::OpRewritePattern; 2027 2028 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 2029 PatternRewriter &rewriter) const override { 2030 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 2031 if (!castOp) 2032 return failure(); 2033 2034 VectorType castSrcType = castOp.getSourceVectorType(); 2035 VectorType castDstType = castOp.getResultVectorType(); 2036 assert(castSrcType.getRank() == castDstType.getRank()); 2037 2038 int64_t castSrcLastDim = castSrcType.getShape().back(); 2039 int64_t castDstLastDim = castDstType.getShape().back(); 2040 // Require casting to more elements for now; other cases to be implemented. 2041 if (castSrcLastDim > castDstLastDim) 2042 return failure(); 2043 2044 // Only accept all one strides for now. 2045 if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(), 2046 [](const APInt &val) { return !val.isOneValue(); })) 2047 return failure(); 2048 2049 unsigned rank = extractOp.getVectorType().getRank(); 2050 assert(castDstLastDim % castSrcLastDim == 0); 2051 int64_t expandRatio = castDstLastDim / castSrcLastDim; 2052 2053 // If we have fewer offsets than the rank, then implicitly we 2054 // are selecting the full range for the last bitcasted dimension; other 2055 // dimensions aren't affected. Otherwise, we need to scale down the last 2056 // dimension's offset given we are extracting from fewer elements now. 2057 ArrayAttr newOffsets = extractOp.offsets(); 2058 if (newOffsets.size() == rank) { 2059 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2060 if (offsets.back() % expandRatio != 0) 2061 return failure(); 2062 offsets.back() = offsets.back() / expandRatio; 2063 newOffsets = rewriter.getI64ArrayAttr(offsets); 2064 } 2065 2066 // Similarly for sizes.
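// E.g. (illustrative): with expandRatio == 2, a sizes attribute of [4] on the
// vector<8xf16> side becomes [2] on the vector<4xf32> source.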
2067 ArrayAttr newSizes = extractOp.sizes(); 2068 if (newSizes.size() == rank) { 2069 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2070 if (sizes.back() % expandRatio != 0) 2071 return failure(); 2072 sizes.back() = sizes.back() / expandRatio; 2073 newSizes = rewriter.getI64ArrayAttr(sizes); 2074 } 2075 2076 SmallVector<int64_t, 4> dims = 2077 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2078 dims.back() = dims.back() / expandRatio; 2079 VectorType newExtractType = 2080 VectorType::get(dims, castSrcType.getElementType()); 2081 2082 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2083 extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, 2084 newSizes, extractOp.strides()); 2085 2086 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2087 extractOp, extractOp.getType(), newExtractOp); 2088 2089 return success(); 2090 } 2091 }; 2092 2093 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2094 // 2095 // This transforms IR like: 2096 // %0 = vector.insert_strided_slice %src, %dst { 2097 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2098 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2099 // Into: 2100 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2101 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2102 // %2 = vector.insert_strided_slice %0, %1 { 2103 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2104 struct BubbleUpBitCastForStridedSliceInsert 2105 : public OpRewritePattern<vector::BitCastOp> { 2106 using OpRewritePattern::OpRewritePattern; 2107 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2108 PatternRewriter &rewriter) const override { 2109 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2110 VectorType castDstType = bitcastOp.getResultVectorType(); 2111 assert(castSrcType.getRank() == castDstType.getRank()); 2112 2113 int64_t castSrcLastDim = castSrcType.getShape().back(); 2114 int64_t castDstLastDim = castDstType.getShape().back(); 2115 // Require casting to fewer elements for now; other cases to be implemented. 2116 if (castSrcLastDim < castDstLastDim) 2117 return failure(); 2118 2119 assert(castSrcLastDim % castDstLastDim == 0); 2120 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2121 2122 auto insertOp = 2123 bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>(); 2124 if (!insertOp) 2125 return failure(); 2126 2127 // Only accept all one strides for now. 2128 if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(), 2129 [](const APInt &val) { return !val.isOneValue(); })) 2130 return failure(); 2131 2132 unsigned rank = insertOp.getSourceVectorType().getRank(); 2133 // Require insert op to have the same rank for the source and destination 2134 // vector; other cases to be implemented.
2135 if (rank != insertOp.getDestVectorType().getRank()) 2136 return failure(); 2137 2138 ArrayAttr newOffsets = insertOp.offsets(); 2139 assert(newOffsets.size() == rank); 2140 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2141 if (offsets.back() % shrinkRatio != 0) 2142 return failure(); 2143 offsets.back() = offsets.back() / shrinkRatio; 2144 newOffsets = rewriter.getI64ArrayAttr(offsets); 2145 2146 SmallVector<int64_t, 4> srcDims = 2147 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2148 srcDims.back() = srcDims.back() / shrinkRatio; 2149 VectorType newCastSrcType = 2150 VectorType::get(srcDims, castDstType.getElementType()); 2151 2152 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2153 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2154 2155 SmallVector<int64_t, 4> dstDims = 2156 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2157 dstDims.back() = dstDims.back() / shrinkRatio; 2158 VectorType newCastDstType = 2159 VectorType::get(dstDims, castDstType.getElementType()); 2160 2161 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2162 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2163 2164 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2165 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2166 insertOp.strides()); 2167 2168 return success(); 2169 } 2170 }; 2171 2172 // Helper that returns a vector comparison that constructs a mask: 2173 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2174 // 2175 // If `dim == 0` then the result will be a 0-D vector. 2176 // 2177 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2178 // much more compact, IR for this operation, but LLVM eventually 2179 // generates more elaborate instructions for this intrinsic since it 2180 // is very conservative on the boundary conditions. 2181 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2182 bool indexOptimizations, int64_t dim, 2183 Value b, Value *off = nullptr) { 2184 auto loc = op->getLoc(); 2185 // If we can assume all indices fit in 32-bit, we perform the vector 2186 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2187 // Otherwise we perform the vector comparison using 64-bit indices. 2188 Type idxType = 2189 indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type(); 2190 DenseIntElementsAttr indicesAttr; 2191 if (dim == 0 && indexOptimizations) { 2192 indicesAttr = DenseIntElementsAttr::get( 2193 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2194 } else if (dim == 0) { 2195 indicesAttr = DenseIntElementsAttr::get( 2196 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2197 } else if (indexOptimizations) { 2198 indicesAttr = rewriter.getI32VectorAttr( 2199 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2200 } else { 2201 indicesAttr = rewriter.getI64VectorAttr( 2202 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2203 } 2204 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2205 // Add in an offset if requested. 2206 if (off) { 2207 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2208 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2209 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2210 } 2211 // Construct the vector comparison. 
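// E.g. (illustrative): for dim = 4 this materializes the i1 mask
//   [0, 1, 2, 3] + [o, o, o, o] < [b, b, b, b]
// (the offset term only when `off` is provided), matching the formula in the
// comment above.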
2212 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2213 Value bounds = 2214 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2215 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2216 bounds); 2217 } 2218 2219 template <typename ConcreteOp> 2220 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2221 public: 2222 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2223 : mlir::OpRewritePattern<ConcreteOp>(context), 2224 indexOptimizations(enableIndexOpt) {} 2225 2226 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2227 PatternRewriter &rewriter) const override { 2228 if (!xferOp.hasOutOfBoundsDim()) 2229 return failure(); 2230 2231 if (xferOp.getVectorType().getRank() > 1 || 2232 llvm::size(xferOp.indices()) == 0) 2233 return failure(); 2234 2235 Location loc = xferOp->getLoc(); 2236 VectorType vtp = xferOp.getVectorType(); 2237 2238 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2239 // set and [dim - offset .. vector_length) unset. 2240 // 2241 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2242 // dimensions here. 2243 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2244 Value off = xferOp.indices()[lastIndex]; 2245 Value dim = 2246 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2247 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2248 Value mask = rewriter.create<vector::CreateMaskOp>( 2249 loc, 2250 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2251 vtp.getNumScalableDims()), 2252 b); 2253 if (xferOp.mask()) { 2254 // Intersect the in-bounds with the mask specified as an op parameter. 2255 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2256 } 2257 2258 rewriter.updateRootInPlace(xferOp, [&]() { 2259 xferOp.maskMutable().assign(mask); 2260 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2261 }); 2262 2263 return success(); 2264 } 2265 2266 private: 2267 const bool indexOptimizations; 2268 }; 2269 2270 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2271 class VectorCreateMaskOpConversion 2272 : public OpRewritePattern<vector::CreateMaskOp> { 2273 public: 2274 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2275 bool enableIndexOpt) 2276 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2277 indexOptimizations(enableIndexOpt) {} 2278 2279 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2280 PatternRewriter &rewriter) const override { 2281 auto dstType = op.getType(); 2282 if (dstType.cast<VectorType>().isScalable()) 2283 return failure(); 2284 int64_t rank = dstType.getRank(); 2285 if (rank > 1) 2286 return failure(); 2287 rewriter.replaceOp( 2288 op, buildVectorComparison(rewriter, op, indexOptimizations, 2289 rank == 0 ? 0 : dstType.getDimSize(0), 2290 op.getOperand(0))); 2291 return success(); 2292 } 2293 2294 private: 2295 const bool indexOptimizations; 2296 }; 2297 2298 // Drop inner most contiguous unit dimensions from transfer_read operand. 2299 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2300 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2301 2302 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2303 PatternRewriter &rewriter) const override { 2304 // TODO: support 0-d corner case. 2305 if (readOp.getTransferRank() == 0) 2306 return failure(); 2307 2308 // TODO: support mask. 
2309 if (readOp.mask()) 2310 return failure(); 2311 2312 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2313 if (!srcType || !srcType.hasStaticShape()) 2314 return failure(); 2315 2316 if (!readOp.permutation_map().isMinorIdentity()) 2317 return failure(); 2318 2319 auto targetType = readOp.getVectorType(); 2320 if (targetType.getRank() <= 1) 2321 return failure(); 2322 2323 SmallVector<int64_t> srcStrides; 2324 int64_t srcOffset; 2325 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2326 return failure(); 2327 2328 size_t dimsToDrop = 0; 2329 for (size_t i = 1; i < srcStrides.size(); ++i) { 2330 int dim = srcType.getRank() - i - 1; 2331 if (srcStrides[dim] == 1) { 2332 dimsToDrop++; 2333 } else { 2334 break; 2335 } 2336 } 2337 if (dimsToDrop == 0) 2338 return failure(); 2339 2340 auto resultTargetVecType = 2341 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2342 targetType.getElementType()); 2343 2344 MemRefType resultMemrefType; 2345 if (srcType.getLayout().getAffineMap().isIdentity()) { 2346 resultMemrefType = MemRefType::get( 2347 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2348 {}, srcType.getMemorySpaceAsInt()); 2349 } else { 2350 AffineMap map = srcType.getLayout().getAffineMap(); 2351 int numResultDims = map.getNumDims() - dimsToDrop; 2352 int numSymbols = map.getNumSymbols(); 2353 for (size_t i = 0; i < dimsToDrop; ++i) { 2354 int dim = srcType.getRank() - i - 1; 2355 map = map.replace(rewriter.getAffineDimExpr(dim), 2356 rewriter.getAffineConstantExpr(0), numResultDims, 2357 numSymbols); 2358 } 2359 resultMemrefType = MemRefType::get( 2360 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2361 map, srcType.getMemorySpaceAsInt()); 2362 } 2363 2364 auto loc = readOp.getLoc(); 2365 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2366 SmallVector<int64_t> strides(srcType.getRank(), 1); 2367 2368 ArrayAttr inBoundsAttr = 2369 readOp.in_bounds() 2370 ? rewriter.getArrayAttr( 2371 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2372 : ArrayAttr(); 2373 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2374 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2375 strides); 2376 auto permMap = getTransferMinorIdentityMap( 2377 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2378 Value result = rewriter.create<vector::TransferReadOp>( 2379 loc, resultTargetVecType, rankedReducedView, 2380 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2381 readOp.padding(), 2382 // TODO: support mask. 2383 /*mask=*/Value(), inBoundsAttr); 2384 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2385 result); 2386 return success(); 2387 } 2388 }; 2389 2390 namespace { 2391 2392 /// This function checks to see if the vector combining kind 2393 /// is consistent with the integer or float element type. 
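/// E.g. (illustrative): MINF and MAXF are only valid for float element
/// types; AND, OR, XOR and the signed/unsigned integer min/max kinds require
/// an integer element type; ADD and MUL are valid for either.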
2394 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2395 using vector::CombiningKind; 2396 enum class KindType { FLOAT, INT, INVALID }; 2397 KindType type{KindType::INVALID}; 2398 switch (kind) { 2399 case CombiningKind::MINF: 2400 case CombiningKind::MAXF: 2401 type = KindType::FLOAT; 2402 break; 2403 case CombiningKind::MINUI: 2404 case CombiningKind::MINSI: 2405 case CombiningKind::MAXUI: 2406 case CombiningKind::MAXSI: 2407 case CombiningKind::AND: 2408 case CombiningKind::OR: 2409 case CombiningKind::XOR: 2410 type = KindType::INT; 2411 break; 2412 case CombiningKind::ADD: 2413 case CombiningKind::MUL: 2414 type = isInt ? KindType::INT : KindType::FLOAT; 2415 break; 2416 } 2417 bool isValidIntKind = (type == KindType::INT) && isInt; 2418 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2419 return (isValidIntKind || isValidFloatKind); 2420 } 2421 2422 /// This function constructs the appropriate integer or float 2423 /// operation given the vector combining kind and operands. The 2424 /// supported int operations are : add, mul, min (signed/unsigned), 2425 /// max(signed/unsigned), and, or, xor. The supported float 2426 /// operations are : add, mul, min and max. 2427 static Value genOperator(Location loc, Value x, Value y, 2428 vector::CombiningKind kind, 2429 PatternRewriter &rewriter) { 2430 using vector::CombiningKind; 2431 2432 auto elType = x.getType().cast<VectorType>().getElementType(); 2433 bool isInt = elType.isIntOrIndex(); 2434 2435 Value combinedResult{nullptr}; 2436 switch (kind) { 2437 case CombiningKind::ADD: 2438 if (isInt) 2439 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2440 else 2441 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2442 break; 2443 case CombiningKind::MUL: 2444 if (isInt) 2445 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2446 else 2447 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2448 break; 2449 case CombiningKind::MINUI: 2450 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2451 break; 2452 case CombiningKind::MINSI: 2453 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2454 break; 2455 case CombiningKind::MAXUI: 2456 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2457 break; 2458 case CombiningKind::MAXSI: 2459 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2460 break; 2461 case CombiningKind::AND: 2462 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2463 break; 2464 case CombiningKind::OR: 2465 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2466 break; 2467 case CombiningKind::XOR: 2468 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2469 break; 2470 case CombiningKind::MINF: 2471 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2472 break; 2473 case CombiningKind::MAXF: 2474 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2475 break; 2476 } 2477 return combinedResult; 2478 } 2479 2480 /// Convert vector.scan op into arith ops and 2481 /// vector.insert_strided_slice/extract_strided_slice 2482 /// 2483 /// Ex: 2484 /// ``` 2485 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2486 /// 1} : 2487 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2488 /// ``` 2489 /// Gets converted to: 2490 /// ``` 2491 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2492 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2493 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32> %1 = 2494 /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} 2495 /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice 2496 /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : 2497 /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : 2498 /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], 2499 /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = 2500 /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], 2501 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, 2502 /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = 2503 /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = 2504 /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : 2505 /// vector<2x3xi32>, vector<2xi32> 2506 /// ``` 2507 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { 2508 using OpRewritePattern<vector::ScanOp>::OpRewritePattern; 2509 2510 LogicalResult matchAndRewrite(vector::ScanOp scanOp, 2511 PatternRewriter &rewriter) const override { 2512 auto loc = scanOp.getLoc(); 2513 VectorType destType = scanOp.getDestType(); 2514 ArrayRef<int64_t> destShape = destType.getShape(); 2515 auto elType = destType.getElementType(); 2516 bool isInt = elType.isIntOrIndex(); 2517 if (!isValidKind(isInt, scanOp.kind())) 2518 return failure(); 2519 2520 VectorType resType = VectorType::get(destShape, elType); 2521 Value result = rewriter.create<arith::ConstantOp>( 2522 loc, resType, rewriter.getZeroAttr(resType)); 2523 int64_t reductionDim = scanOp.reduction_dim(); 2524 bool inclusive = scanOp.inclusive(); 2525 int64_t destRank = destType.getRank(); 2526 VectorType initialValueType = scanOp.getInitialValueType(); 2527 int64_t initialValueRank = initialValueType.getRank(); 2528 2529 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2530 reductionShape[reductionDim] = 1; 2531 VectorType reductionType = VectorType::get(reductionShape, elType); 2532 SmallVector<int64_t> offsets(destRank, 0); 2533 SmallVector<int64_t> strides(destRank, 1); 2534 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2535 sizes[reductionDim] = 1; 2536 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2537 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2538 2539 Value lastOutput, lastInput; 2540 for (int i = 0; i < destShape[reductionDim]; i++) { 2541 offsets[reductionDim] = i; 2542 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2543 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2544 loc, reductionType, scanOp.source(), scanOffsets, scanSizes, 2545 scanStrides); 2546 Value output; 2547 if (i == 0) { 2548 if (inclusive) { 2549 output = input; 2550 } else { 2551 if (initialValueRank == 0) { 2552 // ShapeCastOp cannot handle 0-D vectors 2553 output = rewriter.create<vector::BroadcastOp>( 2554 loc, input.getType(), scanOp.initial_value()); 2555 } else { 2556 output = rewriter.create<vector::ShapeCastOp>( 2557 loc, input.getType(), scanOp.initial_value()); 2558 } 2559 } 2560 } else { 2561 Value y = inclusive ? 
input : lastInput; 2562 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2563 assert(output != nullptr); 2564 } 2565 result = rewriter.create<vector::InsertStridedSliceOp>( 2566 loc, output, result, offsets, strides); 2567 lastOutput = output; 2568 lastInput = input; 2569 } 2570 2571 Value reduction; 2572 if (initialValueRank == 0) { 2573 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2574 reduction = 2575 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2576 } else { 2577 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2578 lastOutput); 2579 } 2580 2581 rewriter.replaceOp(scanOp, {result, reduction}); 2582 return success(); 2583 } 2584 }; 2585 2586 } // namespace 2587 2588 void mlir::vector::populateVectorMaskMaterializationPatterns( 2589 RewritePatternSet &patterns, bool indexOptimizations) { 2590 patterns.add<VectorCreateMaskOpConversion, 2591 MaterializeTransferMask<vector::TransferReadOp>, 2592 MaterializeTransferMask<vector::TransferWriteOp>>( 2593 patterns.getContext(), indexOptimizations); 2594 } 2595 2596 void mlir::vector::populateShapeCastFoldingPatterns( 2597 RewritePatternSet &patterns) { 2598 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2599 } 2600 2601 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2602 RewritePatternSet &patterns) { 2603 patterns.add<BubbleDownVectorBitCastForExtract, 2604 BubbleDownBitCastForStridedSliceExtract, 2605 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2606 } 2607 2608 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2609 RewritePatternSet &patterns) { 2610 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2611 } 2612 2613 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2614 RewritePatternSet &patterns) { 2615 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2616 patterns.getContext()); 2617 } 2618 2619 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2620 RewritePatternSet &patterns) { 2621 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2622 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2623 patterns.getContext()); 2624 } 2625 2626 void mlir::vector::populateVectorContractLoweringPatterns( 2627 RewritePatternSet &patterns, VectorTransformsOptions options) { 2628 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2629 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2630 ContractionOpToOuterProductOpLowering>(options, 2631 patterns.getContext()); 2632 } 2633 2634 void mlir::vector::populateVectorTransposeLoweringPatterns( 2635 RewritePatternSet &patterns, VectorTransformsOptions options) { 2636 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2637 options, patterns.getContext()); 2638 } 2639 2640 void mlir::vector::populateVectorReductionToContractPatterns( 2641 RewritePatternSet &patterns) { 2642 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2643 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2644 ReorderCastOpsOnTranspose>(patterns.getContext()); 2645 } 2646 2647 void mlir::vector:: 2648 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2649 RewritePatternSet &patterns) { 2650 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2651 } 2652 2653 void mlir::vector::populateVectorTransferLoweringPatterns( 2654 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2655 patterns.add<TransferReadToVectorLoadLowering, 2656 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2657 maxTransferRank); 2658 patterns 2659 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2660 patterns.getContext()); 2661 } 2662 2663 void mlir::vector::populateVectorScanLoweringPatterns( 2664 RewritePatternSet &patterns) { 2665 patterns.add<ScanToArithOps>(patterns.getContext()); 2666 } 2667
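// A minimal usage sketch (illustrative; assumes a pass context providing
// `ctx`, `op`, and `options`, and the greedy rewrite driver):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/llvm::None);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));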