//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.vector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering
    : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
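///
/// For instance (an illustrative sketch only; the shapes and SSA names below
/// are invented, not taken from a test), one lowering step of a 2-D mask:
/// ```
///   %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// ```
/// becomes:
/// ```
///   %z = arith.constant dense<false> : vector<4x8xi1>
///   %l = vector.constant_mask [3] : vector<8xi1>
///   %0 = vector.insert %l, %z [0] : vector<8xi1> into vector<4x8xi1>
///   %x = vector.insert %l, %0 [1] : vector<8xi1> into vector<4x8xi1>
/// ```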
class ConstantMaskOpLowering
    : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
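///
/// For instance (an illustrative sketch; shapes and SSA names are invented),
/// one lowering step of a 2-D dynamic mask:
/// ```
///   %x = vector.create_mask %a, %b : vector<2x8xi1>
/// ```
/// becomes roughly:
/// ```
///   %l  = vector.create_mask %b : vector<8xi1>
///   %f  = arith.constant dense<false> : vector<8xi1>
///   %z  = arith.constant dense<false> : vector<2x8xi1>
///   %c0 = arith.constant 0 : index
///   %s0 = arith.cmpi slt, %c0, %a : index
///   %v0 = arith.select %s0, %l, %f : vector<8xi1>
///   %r0 = vector.insert %v0, %z [0] : vector<8xi1> into vector<2x8xi1>
///   ... and the same for row 1, comparing against index 1 ...
/// ```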
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
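///
/// For example (an illustrative sketch; SSA names are invented):
/// ```
///   %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
/// ```
/// is rewritten to roughly:
/// ```
///   %z  = arith.constant dense<0.0> : vector<2x3xf32>
///   %s0 = vector.extract_strided_slice %0
///       {offsets = [0], sizes = [3], strides = [1]}
///       : vector<6xf32> to vector<3xf32>
///   %i0 = vector.insert %s0, %z [0] : vector<3xf32> into vector<2x3xf32>
///   %s1 = vector.extract_strided_slice %0
///       {offsets = [3], sizes = [3], strides = [1]}
///       : vector<6xf32> to vector<3xf32>
///   %1  = vector.insert %s1, %i0 [1] : vector<3xf32> into vector<2x3xf32>
/// ```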
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
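    // As an illustration (not taken from a test), casting vector<2x3xf32>
    // to vector<6xf32> walks the source indices
    //    (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
    // in lockstep with the result indices
    //    (0), (1), (2), (3), (4), (5).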
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vector operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, so that the CombineContractBroadcast pattern can
/// kick in when casting ops sit between them.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  bcastOp.source(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders cast(transpose) to transpose(cast). This makes transpose ops and
/// contraction ops closer, so that the CombineContractTranspose pattern can
/// kick in when casting ops sit between them.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
///   %1 = vector.transpose %0, [2, 0, 1]
///     : vector<32x16x8xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnTranspose
    : public OpInterfaceRewritePattern<CastOpInterface> {

  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
    if (!transpOp)
      return failure();

    auto castResTy = transpOp.getVectorType();
    castResTy = VectorType::get(castResTy.getShape(),
                                getElementTypeOrSelf(op->getResult(0)));
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  transpOp.vector(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResult(0).getType(), castOp->getResult(0),
        transpOp.getTransp());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
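///
/// For illustration (a hypothetical use, with `a` and `b` standing for any
/// two Values of the same float type): createMul(loc, a, b, /*isInt=*/false,
/// rewriter) materializes an `arith.mulf`, while /*isInt=*/true would
/// materialize an `arith.muli`.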
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
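  // For instance (an illustrative sketch with invented shapes), a 4x8 * 8x16
  // contraction now becomes roughly:
  //   %a = vector.shape_cast %lhs : vector<4x8xf32> to vector<32xf32>
  //   %b = vector.shape_cast %rhs : vector<8x16xf32> to vector<128xf32>
  //   %m = vector.matmul %a, %b {lhs_rows = 4 : i32, lhs_columns = 8 : i32,
  //                              rhs_columns = 16 : i32}
  //     : (vector<32xf32>, vector<128xf32>) -> vector<64xf32>
  //   %d = vector.shape_cast %m : vector<64xf32> to vector<4x16xf32>
  // followed by an add with the accumulator (and transposes where the
  // indexing maps require them).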
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
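///
/// For example (an illustrative sketch; the indexing maps are elided and the
/// types are invented), a contraction with a batch dimension of size 2:
/// ```
///   %x = vector.contract {...} %a, %b, %c
///     : vector<2x4x8xf32>, vector<2x8x16xf32> into vector<2x4x16xf32>
/// ```
/// is rewritten into rank-reduced contractions over extracted slices:
/// ```
///   %zero = arith.constant dense<0.0> : vector<2x4x16xf32>
///   %a0 = vector.extract %a[0] : vector<2x4x8xf32>
///   %b0 = vector.extract %b[0] : vector<2x8x16xf32>
///   %c0 = vector.extract %c[0] : vector<2x4x16xf32>
///   %r0 = vector.contract {...} %a0, %b0, %c0
///     : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %t0 = vector.insert %r0, %zero [0]
///     : vector<4x16xf32> into vector<2x4x16xf32>
///   ... the same for batch index 1, inserting into %t0 to produce %x ...
/// ```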
1518 // 1519 // TODO: break down into transpose/reshape/cast ops 1520 // when they become available to avoid code dup 1521 // TODO: investigate lowering order impact on performance 1522 LogicalResult 1523 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1524 PatternRewriter &rewriter) const { 1525 // TODO: implement masks. 1526 if (llvm::size(op.masks()) != 0) 1527 return failure(); 1528 1529 if (failed(filter(op))) 1530 return failure(); 1531 1532 // TODO: support mixed mode contract lowering. 1533 if (op.getLhsType().getElementType() != 1534 getElementTypeOrSelf(op.getAccType()) || 1535 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1536 return failure(); 1537 1538 // TODO: implement benefits, cost models. 1539 MLIRContext *ctx = op.getContext(); 1540 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1541 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1542 return success(); 1543 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1544 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1545 return success(); 1546 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1547 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1548 return success(); 1549 1550 // Find first batch dimension in LHS/RHS, and lower when found. 1551 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1552 if (!batchDimMap.empty()) { 1553 int64_t lhsIndex = batchDimMap[0].first; 1554 int64_t rhsIndex = batchDimMap[0].second; 1555 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1556 return success(); 1557 } 1558 1559 // Collect contracting dimensions. 1560 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1561 op.getContractingDimMap(); 1562 DenseSet<int64_t> lhsContractingDimSet; 1563 DenseSet<int64_t> rhsContractingDimSet; 1564 for (auto &dimPair : contractingDimMap) { 1565 lhsContractingDimSet.insert(dimPair.first); 1566 rhsContractingDimSet.insert(dimPair.second); 1567 } 1568 1569 // Find first free dimension in LHS, and lower when found. 1570 VectorType lhsType = op.getLhsType(); 1571 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1572 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1573 rewriter.replaceOp( 1574 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1575 return success(); 1576 } 1577 } 1578 1579 // Find first free dimension in RHS, and lower when found. 1580 VectorType rhsType = op.getRhsType(); 1581 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1582 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1583 rewriter.replaceOp( 1584 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1585 return success(); 1586 } 1587 } 1588 1589 // Lower the first remaining reduction dimension. 1590 if (!contractingDimMap.empty()) { 1591 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1592 return success(); 1593 } 1594 1595 return failure(); 1596 } 1597 1598 // Lower one parallel dimension. 1599 // TODO: consider reusing existing contract unrolling 1600 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1601 int64_t lhsIndex, int64_t rhsIndex, 1602 PatternRewriter &rewriter) const { 1603 VectorType lhsType = op.getLhsType(); 1604 VectorType rhsType = op.getRhsType(); 1605 VectorType resType = op.getResultType().cast<VectorType>(); 1606 // Find the iterator type index and result index. 
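  // Illustrative example (hypothetical maps, not derived from the operands
  // above): with indexing maps (m, n, k) -> (m, k), (m, n, k) -> (k, n),
  // (m, n, k) -> (m, n) and lhsIndex == 0, the parallel iterator is m, so
  // iterIndex == 0 and the matching result index in the accumulator map is
  // also 0; the loop below then unrolls along the leading dimension of the
  // LHS and of the accumulator.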
1607 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1608 int64_t iterIndex = -1; 1609 int64_t dimSize = -1; 1610 if (lhsIndex >= 0) { 1611 iterIndex = iMap[0].getDimPosition(lhsIndex); 1612 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1613 "parallel index should be free in LHS or batch in LHS/RHS"); 1614 dimSize = lhsType.getDimSize(lhsIndex); 1615 } else { 1616 assert(rhsIndex >= 0 && "missing parallel index"); 1617 iterIndex = iMap[1].getDimPosition(rhsIndex); 1618 dimSize = rhsType.getDimSize(rhsIndex); 1619 } 1620 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1621 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1622 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1623 int64_t resIndex = lookup.getValue(); 1624 // Construct new iterator types and affine map array attribute. 1625 std::array<AffineMap, 3> lowIndexingMaps = { 1626 adjustMap(iMap[0], iterIndex, rewriter), 1627 adjustMap(iMap[1], iterIndex, rewriter), 1628 adjustMap(iMap[2], iterIndex, rewriter)}; 1629 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1630 auto lowIter = 1631 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1632 // Unroll into a series of lower dimensional vector.contract ops. 1633 Location loc = op.getLoc(); 1634 Value result = rewriter.create<arith::ConstantOp>( 1635 loc, resType, rewriter.getZeroAttr(resType)); 1636 for (int64_t d = 0; d < dimSize; ++d) { 1637 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1638 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1639 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1640 Value lowContract = rewriter.create<vector::ContractionOp>( 1641 loc, lhs, rhs, acc, lowAffine, lowIter); 1642 result = 1643 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1644 } 1645 return result; 1646 } 1647 1648 // Lower one reduction dimension. 1649 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1650 PatternRewriter &rewriter) const { 1651 auto loc = op.getLoc(); 1652 VectorType lhsType = op.getLhsType(); 1653 VectorType rhsType = op.getRhsType(); 1654 Type resType = op.getResultType(); 1655 assert(!resType.isa<VectorType>()); 1656 bool isInt = resType.isa<IntegerType>(); 1657 // Use iterator index 0. 1658 int64_t iterIndex = 0; 1659 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1660 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1661 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1662 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1663 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1664 int64_t lhsIndex = lookupLhs.getValue(); 1665 int64_t rhsIndex = lookupRhs.getValue(); 1666 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1667 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1668 // Base case. 1669 if (lhsType.getRank() == 1) { 1670 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1671 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1672 auto kind = vector::CombiningKind::ADD; 1673 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1674 if (auto acc = op.acc()) 1675 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1676 return res; 1677 } 1678 // Construct new iterator types and affine map array attribute. 
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.acc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                    lowAffine, lowIter);
  }
  return result;
}

} // namespace mlir

Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
    OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
    ArrayRef<int64_t> multiplicity, const AffineMap &map) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(op);
  Location loc = op->getLoc();
  if (op->getNumResults() != 1)
    return {};
  Value result = op->getResult(0);
  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
  if (!type || map.getNumResults() != multiplicity.size())
    return {};
  // For each dimension being distributed check that the size is a multiple of
  // the multiplicity. To handle more sizes we would need to support masking.
  unsigned multiplicityCount = 0;
  for (auto exp : map.getResults()) {
    auto affineExp = exp.dyn_cast<AffineDimExpr>();
    if (!affineExp || affineExp.getPosition() >= type.getRank() ||
        type.getDimSize(affineExp.getPosition()) %
                multiplicity[multiplicityCount++] !=
            0)
      return {};
  }
  DistributeOps ops;
  ops.extract =
      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
  ops.insert =
      builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
  return ops;
}

/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferReadOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned, 4> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
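    // Illustrative example (hypothetical types, not taken from a test): an
    // in-bounds 1-D read such as
    //   %v = vector.transfer_read %mem[%i], %pad {in_bounds = [true]}
    //       : memref<16xf32>, vector<8xf32>
    // is rewritten below to
    //   %v = vector.load %mem[%i] : memref<16xf32>, vector<8xf32>
    // and a trailing vector.broadcast is emitted only when the permutation map
    // broadcasts some dimensions.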
1757 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1758 &broadcastedDims)) 1759 return failure(); 1760 1761 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1762 if (!memRefType) 1763 return failure(); 1764 1765 // Non-unit strides are handled by VectorToSCF. 1766 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1767 return failure(); 1768 1769 // If there is broadcasting involved then we first load the unbroadcasted 1770 // vector, and then broadcast it with `vector.broadcast`. 1771 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1772 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1773 vectorShape.end()); 1774 for (unsigned i : broadcastedDims) 1775 unbroadcastedVectorShape[i] = 1; 1776 VectorType unbroadcastedVectorType = VectorType::get( 1777 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1778 1779 // `vector.load` supports vector types as memref's elements only when the 1780 // resulting vector type is the same as the element type. 1781 auto memrefElTy = memRefType.getElementType(); 1782 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1783 return failure(); 1784 1785 // Otherwise, element types of the memref and the vector must match. 1786 if (!memrefElTy.isa<VectorType>() && 1787 memrefElTy != read.getVectorType().getElementType()) 1788 return failure(); 1789 1790 // Out-of-bounds dims are handled by MaterializeTransferMask. 1791 if (read.hasOutOfBoundsDim()) 1792 return failure(); 1793 1794 // Create vector load op. 1795 Operation *loadOp; 1796 if (read.mask()) { 1797 Value fill = rewriter.create<vector::SplatOp>( 1798 read.getLoc(), unbroadcastedVectorType, read.padding()); 1799 loadOp = rewriter.create<vector::MaskedLoadOp>( 1800 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1801 read.mask(), fill); 1802 } else { 1803 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1804 unbroadcastedVectorType, 1805 read.source(), read.indices()); 1806 } 1807 1808 // Insert a broadcasting op if required. 1809 if (!broadcastedDims.empty()) { 1810 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1811 read, read.getVectorType(), loadOp->getResult(0)); 1812 } else { 1813 rewriter.replaceOp(read, loadOp->getResult(0)); 1814 } 1815 1816 return success(); 1817 } 1818 1819 llvm::Optional<unsigned> maxTransferRank; 1820 }; 1821 1822 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1823 // TODO: we shouldn't cross the vector/scalar domains just for this 1824 // but atm we lack the infra to avoid it. Possible solutions include: 1825 // - go directly to LLVM + bitcast 1826 // - introduce a bitcast op and likely a new pointer dialect 1827 // - let memref.load/store additionally support the 0-d vector case 1828 // There are still deeper data layout issues lingering even in this 1829 // trivial case (for architectures for which this matters). 
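// Illustrative rewrite (hypothetical types): a single-element load such as
//   %0 = vector.load %mem[%i] : memref<16xf32>, vector<1xf32>
// becomes
//   %s = memref.load %mem[%i] : memref<16xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>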
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.base(), loadOp.indices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.valueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.valueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.base(), storeOp.indices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferWriteOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return failure();

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.permutation_map().isMinorIdentity())
      return failure();

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
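    // E.g. (illustrative types): writing a vector<4xf32> into a
    // memref<8xvector<4xf32>> is accepted, while writing a vector<2xf32> into
    // that same memref is not handled by this pattern.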
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.mask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.source(), write.indices(), write.mask(), write.vector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.vector(), write.source(), write.indices());
    }
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : vector<4xf32>
//   %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid an infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.position());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.source(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g.
f32 -> vector<2xf16> 1999 VectorType packedType = 2000 VectorType::get({expandRatio}, castDstType.getElementType()); 2001 Value castedValue = rewriter.create<vector::BitCastOp>( 2002 extractOp.getLoc(), packedType, packedValue); 2003 2004 // Finally extract the desired scalar. 2005 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 2006 extractOp, extractOp.getType(), castedValue, 2007 rewriter.getI64ArrayAttr(index % expandRatio)); 2008 2009 return success(); 2010 } 2011 }; 2012 2013 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 2014 // 2015 // This transforms IR like: 2016 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 2017 // %0 = vector.extract_strided_slice %cast { 2018 // offsets = [4], sizes = [4], strides = [1] 2019 // } : vector<8xf16> to vector<4xf16> 2020 // Into: 2021 // %0 = vector.extract_strided_slice %src { 2022 // offsets = [2], sizes = [2], strides = [1] 2023 // } : vector<4xf32> to vector<2xf32> 2024 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 2025 struct BubbleDownBitCastForStridedSliceExtract 2026 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 2027 using OpRewritePattern::OpRewritePattern; 2028 2029 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 2030 PatternRewriter &rewriter) const override { 2031 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 2032 if (!castOp) 2033 return failure(); 2034 2035 VectorType castSrcType = castOp.getSourceVectorType(); 2036 VectorType castDstType = castOp.getResultVectorType(); 2037 assert(castSrcType.getRank() == castDstType.getRank()); 2038 2039 int64_t castSrcLastDim = castSrcType.getShape().back(); 2040 int64_t castDstLastDim = castDstType.getShape().back(); 2041 // Require casting to more elements for now; other cases to be implemented. 2042 if (castSrcLastDim > castDstLastDim) 2043 return failure(); 2044 2045 // Only accept all one strides for now. 2046 if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(), 2047 [](const APInt &val) { return !val.isOneValue(); })) 2048 return failure(); 2049 2050 unsigned rank = extractOp.getVectorType().getRank(); 2051 assert(castDstLastDim % castSrcLastDim == 0); 2052 int64_t expandRatio = castDstLastDim / castSrcLastDim; 2053 2054 // If we have a less number of offsets than the rank, then implicitly we 2055 // are selecting the full range for the last bitcasted dimension; other 2056 // dimensions aren't affected. Otherwise, we need to scale down the last 2057 // dimension's offset given we are extracting from less elements now. 2058 ArrayAttr newOffsets = extractOp.offsets(); 2059 if (newOffsets.size() == rank) { 2060 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2061 if (offsets.back() % expandRatio != 0) 2062 return failure(); 2063 offsets.back() = offsets.back() / expandRatio; 2064 newOffsets = rewriter.getI64ArrayAttr(offsets); 2065 } 2066 2067 // Similarly for sizes. 
2068 ArrayAttr newSizes = extractOp.sizes(); 2069 if (newSizes.size() == rank) { 2070 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2071 if (sizes.back() % expandRatio != 0) 2072 return failure(); 2073 sizes.back() = sizes.back() / expandRatio; 2074 newSizes = rewriter.getI64ArrayAttr(sizes); 2075 } 2076 2077 SmallVector<int64_t, 4> dims = 2078 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2079 dims.back() = dims.back() / expandRatio; 2080 VectorType newExtractType = 2081 VectorType::get(dims, castSrcType.getElementType()); 2082 2083 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2084 extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, 2085 newSizes, extractOp.strides()); 2086 2087 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2088 extractOp, extractOp.getType(), newExtractOp); 2089 2090 return success(); 2091 } 2092 }; 2093 2094 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2095 // 2096 // This transforms IR like: 2097 // %0 = vector.insert_strided_slice %src, %dst { 2098 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2099 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2100 // Into: 2101 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2102 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2103 // %2 = vector.insert_strided_slice %src, %dst { 2104 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2105 struct BubbleUpBitCastForStridedSliceInsert 2106 : public OpRewritePattern<vector::BitCastOp> { 2107 using OpRewritePattern::OpRewritePattern; 2108 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2109 PatternRewriter &rewriter) const override { 2110 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2111 VectorType castDstType = bitcastOp.getResultVectorType(); 2112 assert(castSrcType.getRank() == castDstType.getRank()); 2113 2114 int64_t castSrcLastDim = castSrcType.getShape().back(); 2115 int64_t castDstLastDim = castDstType.getShape().back(); 2116 // Require casting to less elements for now; other cases to be implemented. 2117 if (castSrcLastDim < castDstLastDim) 2118 return failure(); 2119 2120 assert(castSrcLastDim % castDstLastDim == 0); 2121 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2122 2123 auto insertOp = 2124 bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>(); 2125 if (!insertOp) 2126 return failure(); 2127 2128 // Only accept all one strides for now. 2129 if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(), 2130 [](const APInt &val) { return !val.isOneValue(); })) 2131 return failure(); 2132 2133 unsigned rank = insertOp.getSourceVectorType().getRank(); 2134 // Require insert op to have the same rank for the source and destination 2135 // vector; other cases to be implemented. 
2136 if (rank != insertOp.getDestVectorType().getRank()) 2137 return failure(); 2138 2139 ArrayAttr newOffsets = insertOp.offsets(); 2140 assert(newOffsets.size() == rank); 2141 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2142 if (offsets.back() % shrinkRatio != 0) 2143 return failure(); 2144 offsets.back() = offsets.back() / shrinkRatio; 2145 newOffsets = rewriter.getI64ArrayAttr(offsets); 2146 2147 SmallVector<int64_t, 4> srcDims = 2148 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2149 srcDims.back() = srcDims.back() / shrinkRatio; 2150 VectorType newCastSrcType = 2151 VectorType::get(srcDims, castDstType.getElementType()); 2152 2153 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2154 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2155 2156 SmallVector<int64_t, 4> dstDims = 2157 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2158 dstDims.back() = dstDims.back() / shrinkRatio; 2159 VectorType newCastDstType = 2160 VectorType::get(dstDims, castDstType.getElementType()); 2161 2162 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2163 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2164 2165 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2166 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2167 insertOp.strides()); 2168 2169 return success(); 2170 } 2171 }; 2172 2173 // Helper that returns a vector comparison that constructs a mask: 2174 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2175 // 2176 // If `dim == 0` then the result will be a 0-D vector. 2177 // 2178 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2179 // much more compact, IR for this operation, but LLVM eventually 2180 // generates more elaborate instructions for this intrinsic since it 2181 // is very conservative on the boundary conditions. 2182 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2183 bool indexOptimizations, int64_t dim, 2184 Value b, Value *off = nullptr) { 2185 auto loc = op->getLoc(); 2186 // If we can assume all indices fit in 32-bit, we perform the vector 2187 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2188 // Otherwise we perform the vector comparison using 64-bit indices. 2189 Type idxType = 2190 indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type(); 2191 DenseIntElementsAttr indicesAttr; 2192 if (dim == 0 && indexOptimizations) { 2193 indicesAttr = DenseIntElementsAttr::get( 2194 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2195 } else if (dim == 0) { 2196 indicesAttr = DenseIntElementsAttr::get( 2197 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2198 } else if (indexOptimizations) { 2199 indicesAttr = rewriter.getI32VectorAttr( 2200 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2201 } else { 2202 indicesAttr = rewriter.getI64VectorAttr( 2203 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2204 } 2205 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2206 // Add in an offset if requested. 2207 if (off) { 2208 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2209 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2210 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2211 } 2212 // Construct the vector comparison. 
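  // For example (illustrative values): with dim == 4, an offset %o and a bound
  // %b, this produces an arith.cmpi slt of [%o, %o+1, %o+2, %o+3] against the
  // splat of %b, so exactly the lanes whose index is below %b - %o are set.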
2213 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2214 Value bounds = 2215 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2216 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2217 bounds); 2218 } 2219 2220 template <typename ConcreteOp> 2221 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2222 public: 2223 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2224 : mlir::OpRewritePattern<ConcreteOp>(context), 2225 indexOptimizations(enableIndexOpt) {} 2226 2227 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2228 PatternRewriter &rewriter) const override { 2229 if (!xferOp.hasOutOfBoundsDim()) 2230 return failure(); 2231 2232 if (xferOp.getVectorType().getRank() > 1 || 2233 llvm::size(xferOp.indices()) == 0) 2234 return failure(); 2235 2236 Location loc = xferOp->getLoc(); 2237 VectorType vtp = xferOp.getVectorType(); 2238 2239 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2240 // set and [dim - offset .. vector_length) unset. 2241 // 2242 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2243 // dimensions here. 2244 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2245 Value off = xferOp.indices()[lastIndex]; 2246 Value dim = 2247 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2248 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2249 Value mask = rewriter.create<vector::CreateMaskOp>( 2250 loc, 2251 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2252 vtp.getNumScalableDims()), 2253 b); 2254 if (xferOp.mask()) { 2255 // Intersect the in-bounds with the mask specified as an op parameter. 2256 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2257 } 2258 2259 rewriter.updateRootInPlace(xferOp, [&]() { 2260 xferOp.maskMutable().assign(mask); 2261 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2262 }); 2263 2264 return success(); 2265 } 2266 2267 private: 2268 const bool indexOptimizations; 2269 }; 2270 2271 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2272 class VectorCreateMaskOpConversion 2273 : public OpRewritePattern<vector::CreateMaskOp> { 2274 public: 2275 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2276 bool enableIndexOpt) 2277 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2278 indexOptimizations(enableIndexOpt) {} 2279 2280 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2281 PatternRewriter &rewriter) const override { 2282 auto dstType = op.getType(); 2283 if (dstType.cast<VectorType>().isScalable()) 2284 return failure(); 2285 int64_t rank = dstType.getRank(); 2286 if (rank > 1) 2287 return failure(); 2288 rewriter.replaceOp( 2289 op, buildVectorComparison(rewriter, op, indexOptimizations, 2290 rank == 0 ? 0 : dstType.getDimSize(0), 2291 op.getOperand(0))); 2292 return success(); 2293 } 2294 2295 private: 2296 const bool indexOptimizations; 2297 }; 2298 2299 // Drop inner most contiguous unit dimensions from transfer_read operand. 2300 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2301 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2302 2303 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2304 PatternRewriter &rewriter) const override { 2305 // TODO: support 0-d corner case. 2306 if (readOp.getTransferRank() == 0) 2307 return failure(); 2308 2309 // TODO: support mask. 
2310 if (readOp.mask()) 2311 return failure(); 2312 2313 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2314 if (!srcType || !srcType.hasStaticShape()) 2315 return failure(); 2316 2317 if (!readOp.permutation_map().isMinorIdentity()) 2318 return failure(); 2319 2320 auto targetType = readOp.getVectorType(); 2321 if (targetType.getRank() <= 1) 2322 return failure(); 2323 2324 SmallVector<int64_t> srcStrides; 2325 int64_t srcOffset; 2326 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2327 return failure(); 2328 2329 size_t dimsToDrop = 0; 2330 for (size_t i = 1; i < srcStrides.size(); ++i) { 2331 int dim = srcType.getRank() - i - 1; 2332 if (srcStrides[dim] == 1) { 2333 dimsToDrop++; 2334 } else { 2335 break; 2336 } 2337 } 2338 if (dimsToDrop == 0) 2339 return failure(); 2340 2341 auto resultTargetVecType = 2342 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2343 targetType.getElementType()); 2344 2345 MemRefType resultMemrefType; 2346 if (srcType.getLayout().getAffineMap().isIdentity()) { 2347 resultMemrefType = MemRefType::get( 2348 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2349 {}, srcType.getMemorySpaceAsInt()); 2350 } else { 2351 AffineMap map = srcType.getLayout().getAffineMap(); 2352 int numResultDims = map.getNumDims() - dimsToDrop; 2353 int numSymbols = map.getNumSymbols(); 2354 for (size_t i = 0; i < dimsToDrop; ++i) { 2355 int dim = srcType.getRank() - i - 1; 2356 map = map.replace(rewriter.getAffineDimExpr(dim), 2357 rewriter.getAffineConstantExpr(0), numResultDims, 2358 numSymbols); 2359 } 2360 resultMemrefType = MemRefType::get( 2361 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2362 map, srcType.getMemorySpaceAsInt()); 2363 } 2364 2365 auto loc = readOp.getLoc(); 2366 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2367 SmallVector<int64_t> strides(srcType.getRank(), 1); 2368 2369 ArrayAttr inBoundsAttr = 2370 readOp.in_bounds() 2371 ? rewriter.getArrayAttr( 2372 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2373 : ArrayAttr(); 2374 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2375 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2376 strides); 2377 auto permMap = getTransferMinorIdentityMap( 2378 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2379 Value result = rewriter.create<vector::TransferReadOp>( 2380 loc, resultTargetVecType, rankedReducedView, 2381 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2382 readOp.padding(), 2383 // TODO: support mask. 2384 /*mask=*/Value(), inBoundsAttr); 2385 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2386 result); 2387 return success(); 2388 } 2389 }; 2390 2391 namespace { 2392 2393 /// This function checks to see if the vector combining kind 2394 /// is consistent with the integer or float element type. 
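/// For example, MINF/MAXF are only consistent with float elements, the
/// signed/unsigned integer min/max kinds as well as AND/OR/XOR require an
/// integer element type, and ADD/MUL are accepted for both.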
2395 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2396 using vector::CombiningKind; 2397 enum class KindType { FLOAT, INT, INVALID }; 2398 KindType type{KindType::INVALID}; 2399 switch (kind) { 2400 case CombiningKind::MINF: 2401 case CombiningKind::MAXF: 2402 type = KindType::FLOAT; 2403 break; 2404 case CombiningKind::MINUI: 2405 case CombiningKind::MINSI: 2406 case CombiningKind::MAXUI: 2407 case CombiningKind::MAXSI: 2408 case CombiningKind::AND: 2409 case CombiningKind::OR: 2410 case CombiningKind::XOR: 2411 type = KindType::INT; 2412 break; 2413 case CombiningKind::ADD: 2414 case CombiningKind::MUL: 2415 type = isInt ? KindType::INT : KindType::FLOAT; 2416 break; 2417 } 2418 bool isValidIntKind = (type == KindType::INT) && isInt; 2419 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2420 return (isValidIntKind || isValidFloatKind); 2421 } 2422 2423 /// This function constructs the appropriate integer or float 2424 /// operation given the vector combining kind and operands. The 2425 /// supported int operations are : add, mul, min (signed/unsigned), 2426 /// max(signed/unsigned), and, or, xor. The supported float 2427 /// operations are : add, mul, min and max. 2428 static Value genOperator(Location loc, Value x, Value y, 2429 vector::CombiningKind kind, 2430 PatternRewriter &rewriter) { 2431 using vector::CombiningKind; 2432 2433 auto elType = x.getType().cast<VectorType>().getElementType(); 2434 bool isInt = elType.isIntOrIndex(); 2435 2436 Value combinedResult{nullptr}; 2437 switch (kind) { 2438 case CombiningKind::ADD: 2439 if (isInt) 2440 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2441 else 2442 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2443 break; 2444 case CombiningKind::MUL: 2445 if (isInt) 2446 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2447 else 2448 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2449 break; 2450 case CombiningKind::MINUI: 2451 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2452 break; 2453 case CombiningKind::MINSI: 2454 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2455 break; 2456 case CombiningKind::MAXUI: 2457 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2458 break; 2459 case CombiningKind::MAXSI: 2460 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2461 break; 2462 case CombiningKind::AND: 2463 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2464 break; 2465 case CombiningKind::OR: 2466 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2467 break; 2468 case CombiningKind::XOR: 2469 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2470 break; 2471 case CombiningKind::MINF: 2472 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2473 break; 2474 case CombiningKind::MAXF: 2475 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2476 break; 2477 } 2478 return combinedResult; 2479 } 2480 2481 /// Convert vector.scan op into arith ops and 2482 /// vector.insert_strided_slice/extract_strided_slice 2483 /// 2484 /// Ex: 2485 /// ``` 2486 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2487 /// 1} : 2488 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2489 /// ``` 2490 /// Gets converted to: 2491 /// ``` 2492 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2493 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2494 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst {offsets = [0, 0],
///     strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [2, 1],
///     strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.addi %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
///     strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
///     strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.addi %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4 {offsets = [0, 2],
///     strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.kind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.reduction_dim();
    bool inclusive = scanOp.inclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.initial_value());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.initial_value());
          }
        }
      } else {
        Value y = inclusive ?
input : lastInput; 2563 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2564 assert(output != nullptr); 2565 } 2566 result = rewriter.create<vector::InsertStridedSliceOp>( 2567 loc, output, result, offsets, strides); 2568 lastOutput = output; 2569 lastInput = input; 2570 } 2571 2572 Value reduction; 2573 if (initialValueRank == 0) { 2574 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2575 reduction = 2576 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2577 } else { 2578 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2579 lastOutput); 2580 } 2581 2582 rewriter.replaceOp(scanOp, {result, reduction}); 2583 return success(); 2584 } 2585 }; 2586 2587 } // namespace 2588 2589 void mlir::vector::populateVectorMaskMaterializationPatterns( 2590 RewritePatternSet &patterns, bool indexOptimizations) { 2591 patterns.add<VectorCreateMaskOpConversion, 2592 MaterializeTransferMask<vector::TransferReadOp>, 2593 MaterializeTransferMask<vector::TransferWriteOp>>( 2594 patterns.getContext(), indexOptimizations); 2595 } 2596 2597 void mlir::vector::populateShapeCastFoldingPatterns( 2598 RewritePatternSet &patterns) { 2599 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2600 } 2601 2602 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2603 RewritePatternSet &patterns) { 2604 patterns.add<BubbleDownVectorBitCastForExtract, 2605 BubbleDownBitCastForStridedSliceExtract, 2606 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2607 } 2608 2609 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2610 RewritePatternSet &patterns) { 2611 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2612 } 2613 2614 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2615 RewritePatternSet &patterns) { 2616 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2617 patterns.getContext()); 2618 } 2619 2620 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2621 RewritePatternSet &patterns) { 2622 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2623 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2624 patterns.getContext()); 2625 } 2626 2627 void mlir::vector::populateVectorContractLoweringPatterns( 2628 RewritePatternSet &patterns, VectorTransformsOptions options) { 2629 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2630 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2631 ContractionOpToOuterProductOpLowering>(options, 2632 patterns.getContext()); 2633 } 2634 2635 void mlir::vector::populateVectorTransposeLoweringPatterns( 2636 RewritePatternSet &patterns, VectorTransformsOptions options) { 2637 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2638 options, patterns.getContext()); 2639 } 2640 2641 void mlir::vector::populateVectorReductionToContractPatterns( 2642 RewritePatternSet &patterns) { 2643 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2644 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2645 ReorderCastOpsOnTranspose>(patterns.getContext()); 2646 } 2647 2648 void mlir::vector:: 2649 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2650 RewritePatternSet &patterns) { 2651 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2652 } 2653 2654 void mlir::vector::populateVectorTransferLoweringPatterns( 2655 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2656 patterns.add<TransferReadToVectorLoadLowering, 2657 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2658 maxTransferRank); 2659 patterns 2660 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2661 patterns.getContext()); 2662 } 2663 2664 void mlir::vector::populateVectorScanLoweringPatterns( 2665 RewritePatternSet &patterns) { 2666 patterns.add<ScanToArithOps>(patterns.getContext()); 2667 } 2668
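// Illustrative usage (a sketch, not part of this file): a pass that wants
// these rewrites would typically collect them into a RewritePatternSet and
// apply them greedily, for example:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMaskMaterializationPatterns(
//       patterns, /*indexOptimizations=*/true);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorScanLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
//
// where `ctx`, `options`, and `rootOp` are placeholders supplied by the
// surrounding pass.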