1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites as 1->N patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 14 15 #include <type_traits> 16 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 24 #include "mlir/IR/ImplicitLocOpBuilder.h" 25 #include "mlir/IR/Matchers.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/Interfaces/VectorInterfaces.h" 28 29 #include "llvm/ADT/DenseSet.h" 30 #include "llvm/ADT/MapVector.h" 31 #include "llvm/ADT/STLExtras.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/raw_ostream.h" 35 36 #define DEBUG_TYPE "vector-to-vector" 37 38 using namespace mlir; 39 using namespace mlir::vector; 40 41 // Helper to find an index in an affine map. 42 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 43 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 44 int64_t idx = map.getDimPosition(i); 45 if (idx == index) 46 return i; 47 } 48 return None; 49 } 50 51 // Helper to construct iterator types with one index removed. 52 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes, 53 int64_t index) { 54 SmallVector<Attribute, 4> results; 55 for (const auto &it : llvm::enumerate(iteratorTypes)) { 56 int64_t idx = it.index(); 57 if (idx == index) 58 continue; 59 results.push_back(it.value()); 60 } 61 return results; 62 } 63 64 // Helper to construct an affine map with one index removed. 65 static AffineMap adjustMap(AffineMap map, int64_t index, 66 PatternRewriter &rewriter) { 67 auto *ctx = rewriter.getContext(); 68 SmallVector<AffineExpr, 4> results; 69 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 70 int64_t idx = map.getDimPosition(i); 71 if (idx == index) 72 continue; 73 // Re-insert remaining indices, but renamed when occurring 74 // after the removed index. 75 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); 76 results.push_back(targetExpr); 77 } 78 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); 79 } 80 81 // Helper method to possibly drop a dimension in a load. 82 // TODO 83 static Value reshapeLoad(Location loc, Value val, VectorType type, 84 int64_t index, int64_t pos, 85 PatternRewriter &rewriter) { 86 if (index == -1) 87 return val; 88 Type lowType = VectorType::Builder(type).dropDim(0); 89 // At extraction dimension? 90 if (index == 0) { 91 auto posAttr = rewriter.getI64ArrayAttr(pos); 92 return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr); 93 } 94 // Unroll leading dimensions. 
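  // (Descriptive note on the loop below: it peels the leading dimension of
  // `val`: each sub-vector along dimension 0 is extracted, reshaped
  // recursively with `index - 1`, and re-inserted into an all-zero
  // accumulator of the result type.)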
95 VectorType vType = lowType.cast<VectorType>(); 96 Type resType = VectorType::Builder(type).dropDim(index); 97 auto resVectorType = resType.cast<VectorType>(); 98 Value result = rewriter.create<arith::ConstantOp>( 99 loc, resVectorType, rewriter.getZeroAttr(resVectorType)); 100 for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { 101 auto posAttr = rewriter.getI64ArrayAttr(d); 102 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr); 103 Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); 104 result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, 105 posAttr); 106 } 107 return result; 108 } 109 110 // Helper method to possibly drop a dimension in a store. 111 // TODO 112 static Value reshapeStore(Location loc, Value val, Value result, 113 VectorType type, int64_t index, int64_t pos, 114 PatternRewriter &rewriter) { 115 // Unmodified? 116 if (index == -1) 117 return val; 118 // At insertion dimension? 119 if (index == 0) { 120 auto posAttr = rewriter.getI64ArrayAttr(pos); 121 return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr); 122 } 123 // Unroll leading dimensions. 124 Type lowType = VectorType::Builder(type).dropDim(0); 125 VectorType vType = lowType.cast<VectorType>(); 126 Type insType = VectorType::Builder(vType).dropDim(0); 127 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { 128 auto posAttr = rewriter.getI64ArrayAttr(d); 129 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr); 130 Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr); 131 Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); 132 result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr); 133 } 134 return result; 135 } 136 137 template <typename IntType> 138 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) { 139 return llvm::to_vector<4>(llvm::map_range( 140 arrayAttr.getAsRange<IntegerAttr>(), 141 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 142 } 143 144 namespace { 145 146 /// ShapeCastOpFolder folds cancelling ShapeCastOps away. 147 // 148 // Example: 149 // 150 // The following MLIR with cancelling ShapeCastOps: 151 // 152 // %0 = source : vector<5x4x2xf32> 153 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> 154 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> 155 // %3 = user %2 : vector<5x4x2xf32> 156 // 157 // Should canonicalize to the following: 158 // 159 // %0 = source : vector<5x4x2xf32> 160 // %1 = user %0 : vector<5x4x2xf32> 161 // 162 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { 163 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; 164 165 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, 166 PatternRewriter &rewriter) const override { 167 // Check if 'shapeCastOp' has vector source/result type. 168 auto sourceVectorType = 169 shapeCastOp.source().getType().dyn_cast_or_null<VectorType>(); 170 auto resultVectorType = 171 shapeCastOp.result().getType().dyn_cast_or_null<VectorType>(); 172 if (!sourceVectorType || !resultVectorType) 173 return failure(); 174 175 // Check if shape cast op source operand is also a shape cast op. 
176 auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>( 177 shapeCastOp.source().getDefiningOp()); 178 if (!sourceShapeCastOp) 179 return failure(); 180 auto operandSourceVectorType = 181 sourceShapeCastOp.source().getType().cast<VectorType>(); 182 auto operandResultVectorType = sourceShapeCastOp.getType(); 183 184 // Check if shape cast operations invert each other. 185 if (operandSourceVectorType != resultVectorType || 186 operandResultVectorType != sourceVectorType) 187 return failure(); 188 189 rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source()); 190 return success(); 191 } 192 }; 193 194 /// Progressive lowering of BroadcastOp. 195 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { 196 public: 197 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; 198 199 LogicalResult matchAndRewrite(vector::BroadcastOp op, 200 PatternRewriter &rewriter) const override { 201 auto loc = op.getLoc(); 202 VectorType dstType = op.getVectorType(); 203 VectorType srcType = op.getSourceType().dyn_cast<VectorType>(); 204 Type eltType = dstType.getElementType(); 205 206 // Scalar to any vector can use splat. 207 if (!srcType) { 208 rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source()); 209 return success(); 210 } 211 212 // Determine rank of source and destination. 213 int64_t srcRank = srcType.getRank(); 214 int64_t dstRank = dstType.getRank(); 215 216 // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. 217 if (srcRank <= 1 && dstRank == 1) { 218 Value ext; 219 if (srcRank == 0) 220 ext = rewriter.create<vector::ExtractElementOp>(loc, op.source()); 221 else 222 ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0); 223 rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); 224 return success(); 225 } 226 227 // Duplicate this rank. 228 // For example: 229 // %x = broadcast %y : k-D to n-D, k < n 230 // becomes: 231 // %b = broadcast %y : k-D to (n-1)-D 232 // %x = [%b,%b,%b,%b] : n-D 233 // becomes: 234 // %b = [%y,%y] : (n-1)-D 235 // %x = [%b,%b,%b,%b] : n-D 236 if (srcRank < dstRank) { 237 // Duplication. 238 VectorType resType = 239 VectorType::get(dstType.getShape().drop_front(), eltType); 240 Value bcst = 241 rewriter.create<vector::BroadcastOp>(loc, resType, op.source()); 242 Value result = rewriter.create<arith::ConstantOp>( 243 loc, dstType, rewriter.getZeroAttr(dstType)); 244 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) 245 result = rewriter.create<vector::InsertOp>(loc, bcst, result, d); 246 rewriter.replaceOp(op, result); 247 return success(); 248 } 249 250 // Find non-matching dimension, if any. 251 assert(srcRank == dstRank); 252 int64_t m = -1; 253 for (int64_t r = 0; r < dstRank; r++) 254 if (srcType.getDimSize(r) != dstType.getDimSize(r)) { 255 m = r; 256 break; 257 } 258 259 // All trailing dimensions are the same. Simply pass through. 260 if (m == -1) { 261 rewriter.replaceOp(op, op.source()); 262 return success(); 263 } 264 265 // Any non-matching dimension forces a stretch along this rank. 
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Return the number of leading dimensions in 'transpose', up to and
/// including the rightmost dimension that is actually transposed (i.e., not
/// mapped to itself).
size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }
  return numTransposedDims;
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
351 if (vectorTransformOptions.vectorTransposeLowering == 352 vector::VectorTransposeLowering::Flat && 353 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { 354 Type flattenedType = 355 VectorType::get(resType.getNumElements(), resType.getElementType()); 356 auto matrix = 357 rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector()); 358 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); 359 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); 360 Value trans = rewriter.create<vector::FlatTransposeOp>( 361 loc, flattenedType, matrix, rows, columns); 362 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); 363 return success(); 364 } 365 366 // Generate unrolled extract/insert ops. We do not unroll the rightmost 367 // (i.e., highest-order) dimensions that are not transposed and leave them 368 // in vector form to improve performance. 369 size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp); 370 371 // The type of the extract operation will be scalar if all the dimensions 372 // are unrolled. Otherwise, it will be a vector with the shape of the 373 // dimensions that are not transposed. 374 Type extractType = 375 numLeftmostTransposedDims == transp.size() 376 ? resType.getElementType() 377 : VectorType::Builder(resType).setShape( 378 resType.getShape().drop_front(numLeftmostTransposedDims)); 379 380 Value result = rewriter.create<arith::ConstantOp>( 381 loc, resType, rewriter.getZeroAttr(resType)); 382 SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0); 383 SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0); 384 rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0, 385 numLeftmostTransposedDims, transp, lhs, 386 rhs, op.vector(), result, rewriter)); 387 return success(); 388 } 389 390 private: 391 // Builds the indices arrays for the lhs and rhs. Generates the extract/insert 392 // operations when all the ranks go over the last dimension being transposed. 393 Value expandIndices(Location loc, VectorType resType, Type extractType, 394 int64_t pos, int64_t numLeftmostTransposedDims, 395 SmallVector<int64_t, 4> &transp, 396 SmallVector<int64_t, 4> &lhs, 397 SmallVector<int64_t, 4> &rhs, Value input, Value result, 398 PatternRewriter &rewriter) const { 399 if (pos >= numLeftmostTransposedDims) { 400 auto ridx = rewriter.getI64ArrayAttr(rhs); 401 auto lidx = rewriter.getI64ArrayAttr(lhs); 402 Value e = 403 rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx); 404 return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx); 405 } 406 for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) { 407 lhs[pos] = d; 408 rhs[transp[pos]] = d; 409 result = expandIndices(loc, resType, extractType, pos + 1, 410 numLeftmostTransposedDims, transp, lhs, rhs, input, 411 result, rewriter); 412 } 413 return result; 414 } 415 416 /// Options to control the vector patterns. 
417 vector::VectorTransformsOptions vectorTransformOptions; 418 }; 419 420 /// Rewrite a 2-D vector.transpose as a sequence of: 421 /// vector.shape_cast 2D -> 1D 422 /// vector.shuffle 423 /// vector.shape_cast 1D -> 2D 424 class TransposeOp2DToShuffleLowering 425 : public OpRewritePattern<vector::TransposeOp> { 426 public: 427 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 428 429 TransposeOp2DToShuffleLowering( 430 vector::VectorTransformsOptions vectorTransformOptions, 431 MLIRContext *context) 432 : OpRewritePattern<vector::TransposeOp>(context), 433 vectorTransformOptions(vectorTransformOptions) {} 434 435 LogicalResult matchAndRewrite(vector::TransposeOp op, 436 PatternRewriter &rewriter) const override { 437 auto loc = op.getLoc(); 438 439 VectorType srcType = op.getVectorType(); 440 if (srcType.getRank() != 2) 441 return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); 442 443 SmallVector<int64_t, 4> transp; 444 for (auto attr : op.transp()) 445 transp.push_back(attr.cast<IntegerAttr>().getInt()); 446 if (transp[0] != 1 && transp[1] != 0) 447 return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); 448 449 if (vectorTransformOptions.vectorTransposeLowering != 450 VectorTransposeLowering::Shuffle) 451 return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle"); 452 453 int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); 454 Value casted = rewriter.create<vector::ShapeCastOp>( 455 loc, VectorType::get({m * n}, srcType.getElementType()), op.vector()); 456 SmallVector<int64_t> mask; 457 mask.reserve(m * n); 458 for (int64_t j = 0; j < n; ++j) 459 for (int64_t i = 0; i < m; ++i) 460 mask.push_back(i * n + j); 461 462 Value shuffled = 463 rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask); 464 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(), 465 shuffled); 466 467 return success(); 468 } 469 470 private: 471 /// Options to control the vector patterns. 472 vector::VectorTransformsOptions vectorTransformOptions; 473 }; 474 475 /// Progressive lowering of OuterProductOp. 476 /// One: 477 /// %x = vector.outerproduct %lhs, %rhs, %acc 478 /// is replaced by: 479 /// %z = zero-result 480 /// %0 = vector.extract %lhs[0] 481 /// %1 = vector.broadcast %0 482 /// %2 = vector.extract %acc[0] 483 /// %3 = vector.fma %1, %rhs, %2 484 /// %4 = vector.insert %3, %z[0] 485 /// .. 486 /// %x = vector.insert %.., %..[N-1] 487 /// 488 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { 489 public: 490 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; 491 492 LogicalResult matchAndRewrite(vector::OuterProductOp op, 493 PatternRewriter &rewriter) const override { 494 auto loc = op.getLoc(); 495 496 VectorType lhsType = op.getOperandVectorTypeLHS(); 497 VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>(); 498 VectorType resType = op.getVectorType(); 499 Type eltType = resType.getElementType(); 500 bool isInt = eltType.isa<IntegerType, IndexType>(); 501 Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; 502 vector::CombiningKind kind = op.kind(); 503 504 if (!rhsType) { 505 // Special case: AXPY operation. 506 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs()); 507 Optional<Value> mult = 508 isInt ? 
genMultI(loc, op.lhs(), b, acc, kind, rewriter) 509 : genMultF(loc, op.lhs(), b, acc, kind, rewriter); 510 if (!mult.hasValue()) 511 return failure(); 512 rewriter.replaceOp(op, mult.getValue()); 513 return success(); 514 } 515 516 Value result = rewriter.create<arith::ConstantOp>( 517 loc, resType, rewriter.getZeroAttr(resType)); 518 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { 519 auto pos = rewriter.getI64ArrayAttr(d); 520 Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos); 521 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x); 522 Value r = nullptr; 523 if (acc) 524 r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos); 525 Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) 526 : genMultF(loc, a, op.rhs(), r, kind, rewriter); 527 if (!m.hasValue()) 528 return failure(); 529 result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(), 530 result, pos); 531 } 532 rewriter.replaceOp(op, result); 533 return success(); 534 } 535 536 private: 537 static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc, 538 vector::CombiningKind kind, 539 PatternRewriter &rewriter) { 540 using vector::CombiningKind; 541 542 auto mul = rewriter.create<arith::MulIOp>(loc, x, y); 543 if (!acc) 544 return Optional<Value>(mul); 545 546 if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) 547 // Only valid for floating point types. 548 return Optional<Value>(); 549 550 return makeArithReduction(rewriter, loc, kind, mul, acc); 551 } 552 553 static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc, 554 vector::CombiningKind kind, 555 PatternRewriter &rewriter) { 556 using vector::CombiningKind; 557 558 // Special case for fused multiply-add. 559 if (acc && kind == CombiningKind::ADD) { 560 return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc)); 561 } 562 563 auto mul = rewriter.create<arith::MulFOp>(loc, x, y); 564 565 if (!acc) 566 return Optional<Value>(mul); 567 568 if (kind == CombiningKind::ADD || kind == CombiningKind::AND || 569 kind == CombiningKind::MINUI || kind == CombiningKind::MINSI || 570 kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI || 571 kind == CombiningKind::OR || kind == CombiningKind::XOR) 572 // Already handled or only valid for integer types. 573 return Optional<Value>(); 574 575 return makeArithReduction(rewriter, loc, kind, mul, acc); 576 } 577 }; 578 579 /// Progressive lowering of ConstantMaskOp. 580 /// One: 581 /// %x = vector.constant_mask [a,b] 582 /// is replaced by: 583 /// %z = zero-result 584 /// %l = vector.constant_mask [b] 585 /// %4 = vector.insert %l, %z[0] 586 /// .. 587 /// %x = vector.insert %l, %..[a-1] 588 /// until a one-dimensional vector is reached. All these operations 589 /// will be folded at LLVM IR level. 
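///
/// For example (an illustrative sketch; the SSA names and shapes are made up,
/// not taken from the tests), a 2-D mask such as
///   %x = vector.constant_mask [2, 2] : vector<3x4xi1>
/// is rewritten along the lines of:
///   %z = arith.constant dense<false> : vector<3x4xi1>
///   %l = vector.constant_mask [2] : vector<4xi1>
///   %0 = vector.insert %l, %z [0] : vector<4xi1> into vector<3x4xi1>
///   %x = vector.insert %l, %0 [1] : vector<4xi1> into vector<3x4xi1>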
590 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { 591 public: 592 using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern; 593 594 LogicalResult matchAndRewrite(vector::ConstantMaskOp op, 595 PatternRewriter &rewriter) const override { 596 auto loc = op.getLoc(); 597 auto dstType = op.getType(); 598 auto eltType = dstType.getElementType(); 599 auto dimSizes = op.mask_dim_sizes(); 600 int64_t rank = dstType.getRank(); 601 602 if (rank == 0) { 603 assert(dimSizes.size() == 1 && 604 "Expected exactly one dim size for a 0-D vector"); 605 bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1; 606 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 607 op, dstType, 608 DenseIntElementsAttr::get( 609 VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()), 610 ArrayRef<bool>{value})); 611 return success(); 612 } 613 614 int64_t trueDim = std::min(dstType.getDimSize(0), 615 dimSizes[0].cast<IntegerAttr>().getInt()); 616 617 if (rank == 1) { 618 // Express constant 1-D case in explicit vector form: 619 // [T,..,T,F,..,F]. 620 SmallVector<bool, 4> values(dstType.getDimSize(0)); 621 for (int64_t d = 0; d < trueDim; d++) 622 values[d] = true; 623 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 624 op, dstType, rewriter.getBoolVectorAttr(values)); 625 return success(); 626 } 627 628 VectorType lowType = 629 VectorType::get(dstType.getShape().drop_front(), eltType); 630 SmallVector<int64_t, 4> newDimSizes; 631 for (int64_t r = 1; r < rank; r++) 632 newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt()); 633 Value trueVal = rewriter.create<vector::ConstantMaskOp>( 634 loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); 635 Value result = rewriter.create<arith::ConstantOp>( 636 loc, dstType, rewriter.getZeroAttr(dstType)); 637 for (int64_t d = 0; d < trueDim; d++) { 638 auto pos = rewriter.getI64ArrayAttr(d); 639 result = 640 rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos); 641 } 642 rewriter.replaceOp(op, result); 643 return success(); 644 } 645 }; 646 647 /// Progressive lowering of CreateMaskOp. 648 /// One: 649 /// %x = vector.create_mask %a, ... : vector<dx...> 650 /// is replaced by: 651 /// %l = vector.create_mask ... : vector<...> ; one lower rank 652 /// %0 = arith.cmpi "slt", %ci, %a | 653 /// %1 = select %0, %l, %zeroes | 654 /// %r = vector.insert %1, %pr [i] | d-times 655 /// %x = .... 656 /// until a one-dimensional vector is reached. 
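///
/// For example (an illustrative sketch; names and shapes are made up), for
///   %x = vector.create_mask %a, %b : vector<2x3xi1>
/// the pattern emits roughly:
///   %l = vector.create_mask %b : vector<3xi1>
///   %f = arith.constant dense<false> : vector<3xi1>
///   %z = arith.constant dense<false> : vector<2x3xi1>
///   %c0 = arith.constant 0 : index
///   %0 = arith.cmpi slt, %c0, %a : index
///   %1 = arith.select %0, %l, %f : vector<3xi1>
///   %2 = vector.insert %1, %z [0]
///   .. and similarly for row 1, comparing %c1 against %a.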
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
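///
/// For example (an illustrative sketch; names and shapes are made up):
///   %x = vector.shape_cast %y : vector<6xf32> to vector<2x3xf32>
/// is rewritten roughly as:
///   %z = arith.constant dense<0.0> : vector<2x3xf32>
///   %s0 = vector.extract_strided_slice %y
///           {offsets = [0], sizes = [3], strides = [1]}
///   %0 = vector.insert %s0, %z [0]
///   %s1 = vector.extract_strided_slice %y
///           {offsets = [3], sizes = [3], strides = [1]}
///   %x = vector.insert %s1, %0 [1]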
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
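    // For example (sketch), vector<2x2xf32> -> vector<4xf32> emits four
    // extract/insert pairs: x[0] = y[0, 0], x[1] = y[0, 1], x[2] = y[1, 0],
    // x[3] = y[1, 1].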
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
/// %1 = vector.multi_reduction add, %0 [2]
///   : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.transpose %arg0, [2, 0, 1]
///   : vector<32x16x8xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///        affine_map<(d0, d1, d2) -> (d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///        affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vector types as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
/// arith::AddFOp, using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
/// arith::MulFOp, using operands `x` and `y`.
1025 static Value createMul(Location loc, Value x, Value y, bool isInt, 1026 PatternRewriter &rewriter) { 1027 if (isInt) 1028 return rewriter.create<arith::MulIOp>(loc, x, y); 1029 return rewriter.create<arith::MulFOp>(loc, x, y); 1030 } 1031 1032 namespace mlir { 1033 1034 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1035 /// semantics to: 1036 /// ``` 1037 /// %mta = maybe_transpose 1038 /// %mtb = maybe_transpose 1039 /// %flattened_a = vector.shape_cast %mta 1040 /// %flattened_b = vector.shape_cast %mtb 1041 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 1042 /// %mtd = vector.shape_cast %flattened_d 1043 /// %d = maybe_untranspose %mtd 1044 /// %e = add %c, %d 1045 /// ``` 1046 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 1047 // 1048 /// This only kicks in when VectorTransformsOptions is set to `Matmul`. 1049 /// vector.transpose operations are inserted if the vector.contract op is not a 1050 /// row-major matrix multiply. 1051 LogicalResult 1052 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, 1053 PatternRewriter &rew) const { 1054 // TODO: implement masks 1055 if (llvm::size(op.masks()) != 0) 1056 return failure(); 1057 if (vectorTransformOptions.vectorContractLowering != 1058 vector::VectorContractLowering::Matmul) 1059 return failure(); 1060 if (failed(filter(op))) 1061 return failure(); 1062 1063 auto iteratorTypes = op.iterator_types().getValue(); 1064 if (!isParallelIterator(iteratorTypes[0]) || 1065 !isParallelIterator(iteratorTypes[1]) || 1066 !isReductionIterator(iteratorTypes[2])) 1067 return failure(); 1068 1069 Type elementType = op.getLhsType().getElementType(); 1070 if (!elementType.isIntOrFloat()) 1071 return failure(); 1072 1073 // Perform lhs + rhs transpositions to conform to matmul row-major semantics. 1074 // Bail out if the contraction cannot be put in this form. 1075 MLIRContext *ctx = op.getContext(); 1076 Location loc = op.getLoc(); 1077 AffineExpr m, n, k; 1078 bindDims(rew.getContext(), m, n, k); 1079 // LHS must be A(m, k) or A(k, m). 1080 Value lhs = op.lhs(); 1081 auto lhsMap = op.indexing_maps()[0]; 1082 if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) 1083 lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0}); 1084 else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) 1085 return failure(); 1086 1087 // RHS must be B(k, n) or B(n, k). 1088 Value rhs = op.rhs(); 1089 auto rhsMap = op.indexing_maps()[1]; 1090 if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) 1091 rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0}); 1092 else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) 1093 return failure(); 1094 1095 // At this point lhs and rhs are in row-major. 
1096 VectorType lhsType = lhs.getType().cast<VectorType>(); 1097 VectorType rhsType = rhs.getType().cast<VectorType>(); 1098 int64_t lhsRows = lhsType.getDimSize(0); 1099 int64_t lhsColumns = lhsType.getDimSize(1); 1100 int64_t rhsColumns = rhsType.getDimSize(1); 1101 1102 Type flattenedLHSType = 1103 VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); 1104 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs); 1105 1106 Type flattenedRHSType = 1107 VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); 1108 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs); 1109 1110 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns, 1111 rhsColumns); 1112 mul = rew.create<vector::ShapeCastOp>( 1113 loc, 1114 VectorType::get({lhsRows, rhsColumns}, 1115 getElementTypeOrSelf(op.acc().getType())), 1116 mul); 1117 1118 // ACC must be C(m, n) or C(n, m). 1119 auto accMap = op.indexing_maps()[2]; 1120 if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) 1121 mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0}); 1122 else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) 1123 llvm_unreachable("invalid contraction semantics"); 1124 1125 Value res = 1126 elementType.isa<IntegerType>() 1127 ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul)) 1128 : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul)); 1129 1130 rew.replaceOp(op, res); 1131 return success(); 1132 } 1133 1134 namespace { 1135 struct IteratorType { 1136 IteratorType(StringRef strRef) : strRef(strRef) {} 1137 bool isOfType(Attribute attr) const { 1138 auto sAttr = attr.dyn_cast<StringAttr>(); 1139 return sAttr && sAttr.getValue() == strRef; 1140 } 1141 StringRef strRef; 1142 }; 1143 struct Par : public IteratorType { 1144 Par() : IteratorType(getParallelIteratorTypeName()) {} 1145 }; 1146 struct Red : public IteratorType { 1147 Red() : IteratorType(getReductionIteratorTypeName()) {} 1148 }; 1149 1150 /// Generate a vector implementation for matmat, matvec and tmatvec. 1151 /// This unrolls outer-products along the reduction dimension. 1152 struct UnrolledOuterProductGenerator 1153 : public StructuredGenerator<vector::ContractionOp> { 1154 1155 UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) 1156 : StructuredGenerator<vector::ContractionOp>(builder, op), 1157 kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()), 1158 lhsType(op.getLhsType()) {} 1159 1160 Value t(Value v) { 1161 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1162 return builder.create<vector::TransposeOp>(loc, v, perm); 1163 } 1164 1165 Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { 1166 assert(reductionSize > 0); 1167 for (int64_t k = 0; k < reductionSize; ++k) { 1168 Value a = builder.create<vector::ExtractOp>(loc, lhs, k); 1169 Value b = builder.create<vector::ExtractOp>(loc, rhs, k); 1170 res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b, 1171 res, kind); 1172 } 1173 return res; 1174 } 1175 1176 /// Two outer parallel, one inner reduction (matmat flavor). 1177 FailureOr<Value> matmat() { 1178 if (!iters({Par(), Par(), Red()})) 1179 return failure(); 1180 // Set up the parallel/reduction structure in the right form. 1181 AffineExpr m, n, k; 1182 bindDims(builder.getContext(), m, n, k); 1183 // Classical row-major matmul: Just permute the lhs. 
1184 if (layout({{m, k}, {k, n}, {m, n}})) 1185 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1186 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 1187 if (layout({{m, k}, {n, k}, {m, n}})) { 1188 Value tlhs = t(lhs); 1189 return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1)); 1190 } 1191 // No need to permute anything. 1192 if (layout({{k, m}, {k, n}, {m, n}})) 1193 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1194 // Just permute the rhs. 1195 if (layout({{k, m}, {n, k}, {m, n}})) 1196 return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0)); 1197 // Transposed output: swap RHS and LHS. 1198 // Classical row-major matmul: permute the lhs. 1199 if (layout({{m, k}, {k, n}, {n, m}})) 1200 return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1)); 1201 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 1202 if (layout({{m, k}, {n, k}, {n, m}})) { 1203 Value trhs = t(rhs); 1204 return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1)); 1205 } 1206 if (layout({{k, m}, {k, n}, {n, m}})) 1207 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1208 if (layout({{k, m}, {n, k}, {n, m}})) 1209 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1210 return failure(); 1211 } 1212 1213 /// One outer parallel, one inner reduction (matvec flavor) 1214 FailureOr<Value> matvec() { 1215 if (!iters({Par(), Red()})) 1216 return failure(); 1217 AffineExpr m, k; 1218 bindDims(builder.getContext(), m, k); 1219 1220 // Case mat-vec: transpose. 1221 if (layout({{m, k}, {k}, {m}})) 1222 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1223 // Case mat-trans-vec: ready to go. 1224 if (layout({{k, m}, {k}, {m}})) 1225 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1226 // Case vec-mat: swap and transpose. 1227 if (layout({{k}, {m, k}, {m}})) 1228 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1229 // Case vec-mat-trans: swap and ready to go. 1230 if (layout({{k}, {k, m}, {m}})) 1231 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1232 return failure(); 1233 } 1234 1235 // 1236 // One outer reduction, one inner parallel (tmatvec flavor) 1237 // 1238 FailureOr<Value> tmatvec() { 1239 if (!iters({Red(), Par()})) 1240 return failure(); 1241 AffineExpr k, m; 1242 bindDims(builder.getContext(), k, m); 1243 1244 // Case mat-vec: transpose. 1245 if (layout({{m, k}, {k}, {m}})) 1246 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1247 // Case mat-trans-vec: ready to go. 1248 if (layout({{k, m}, {k}, {m}})) 1249 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1250 // Case vec-mat: swap and transpose. 1251 if (layout({{k}, {m, k}, {m}})) 1252 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1253 // Case vec-mat-trans: swap and ready to go. 1254 if (layout({{k}, {k, m}, {m}})) 1255 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1256 return failure(); 1257 } 1258 1259 private: 1260 vector::CombiningKind kind; 1261 Value lhs, rhs, res; 1262 VectorType lhsType; 1263 }; 1264 } // namespace 1265 1266 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1267 /// semantics to a reduction_size-unrolled sequence: 1268 /// ``` 1269 /// %at = vector.transpose %a, [1, 0] 1270 /// %bRow0 = vector.extract %b[0] 1271 /// %atRow0 = vector.extract %at[0] 1272 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 1273 /// ... 
1274 /// %bRowK = vector.extract %b[K] 1275 /// %atRowK = vector.extract %at[K] 1276 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 1277 /// ``` 1278 /// 1279 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but 1280 /// otherwise supports any layout permutation of the matrix-multiply. 1281 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( 1282 vector::ContractionOp op, PatternRewriter &rewriter) const { 1283 // TODO: implement masks 1284 if (llvm::size(op.masks()) != 0) 1285 return failure(); 1286 1287 if (vectorTransformOptions.vectorContractLowering != 1288 vector::VectorContractLowering::OuterProduct) 1289 return failure(); 1290 1291 if (failed(filter(op))) 1292 return failure(); 1293 1294 UnrolledOuterProductGenerator e(rewriter, op); 1295 FailureOr<Value> matmatRes = e.matmat(); 1296 if (succeeded(matmatRes)) { 1297 rewriter.replaceOp(op, *matmatRes); 1298 return success(); 1299 } 1300 FailureOr<Value> matvecRes = e.matvec(); 1301 if (succeeded(matvecRes)) { 1302 rewriter.replaceOp(op, *matvecRes); 1303 return success(); 1304 } 1305 FailureOr<Value> tmatvecRes = e.tmatvec(); 1306 if (succeeded(tmatvecRes)) { 1307 rewriter.replaceOp(op, *tmatvecRes); 1308 return success(); 1309 } 1310 1311 return failure(); 1312 } 1313 1314 LogicalResult 1315 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, 1316 PatternRewriter &rewriter) const { 1317 // TODO: implement masks 1318 if (llvm::size(op.masks()) != 0) 1319 return failure(); 1320 1321 if (failed(filter(op))) 1322 return failure(); 1323 1324 if (vectorTransformOptions.vectorContractLowering != 1325 vector::VectorContractLowering::Dot) 1326 return failure(); 1327 1328 auto iteratorTypes = op.iterator_types().getValue(); 1329 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1330 Location loc = op.getLoc(); 1331 Value lhs = op.lhs(), rhs = op.rhs(); 1332 1333 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 1334 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; 1335 AffineExpr m, n, k; 1336 bindDims(rewriter.getContext(), m, n, k); 1337 SmallVector<AffineMap, 4> maps = op.getIndexingMaps(); 1338 // 1339 // In the following we wish to make the reduction dimension innermost so we 1340 // can load vectors and just fmul + reduce into a scalar. 1341 // 1342 if (isParallelIterator(iteratorTypes[0]) && 1343 isParallelIterator(iteratorTypes[1]) && 1344 isReductionIterator(iteratorTypes[2])) { 1345 // 1346 // Two outer parallel, one inner reduction (matmat flavor). 1347 // 1348 if (maps == infer({{m, k}, {k, n}, {m, n}})) { 1349 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1350 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { 1351 // No need to permute anything. 1352 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 1353 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1354 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1355 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 1356 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1357 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 1358 // This is the classical row-major matmul. Just permute the lhs. 
1359 Value tmp = lhs; 1360 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1361 rhs = tmp; 1362 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 1363 std::swap(lhs, rhs); 1364 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 1365 Value tmp = lhs; 1366 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1367 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm); 1368 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 1369 Value tmp = rhs; 1370 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1371 lhs = tmp; 1372 } else { 1373 return failure(); 1374 } 1375 } else if (isParallelIterator(iteratorTypes[0]) && 1376 isReductionIterator(iteratorTypes[1])) { 1377 // 1378 // One outer parallel, one inner reduction (matvec flavor) 1379 // 1380 if (maps == infer({{m, n}, {n}, {m}})) { 1381 // No need to permute anything. 1382 } else if (maps == infer({{n, m}, {n}, {m}})) { 1383 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1384 } else if (maps == infer({{n}, {m, n}, {m}})) { 1385 std::swap(lhs, rhs); 1386 } else if (maps == infer({{n}, {n, m}, {m}})) { 1387 std::swap(lhs, rhs); 1388 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1389 } else { 1390 return failure(); 1391 } 1392 } else { 1393 return failure(); 1394 } 1395 1396 VectorType dstType = op.getResultType().cast<VectorType>(); 1397 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && 1398 "Expected dst type of rank 1 or 2"); 1399 1400 unsigned rank = dstType.getRank(); 1401 unsigned dstRows = dstType.getShape()[0]; 1402 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; 1403 1404 // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 1405 Value res = rewriter.create<arith::ConstantOp>(loc, dstType, 1406 rewriter.getZeroAttr(dstType)); 1407 bool isInt = dstType.getElementType().isa<IntegerType>(); 1408 for (unsigned r = 0; r < dstRows; ++r) { 1409 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r); 1410 for (unsigned c = 0; c < dstColumns; ++c) { 1411 Value b = rank == 1 1412 ? rhs 1413 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c); 1414 Value m = createMul(op.getLoc(), a, b, isInt, rewriter); 1415 Value reduced = rewriter.create<vector::ReductionOp>( 1416 op.getLoc(), vector::CombiningKind::ADD, m); 1417 1418 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r} 1419 : SmallVector<int64_t, 2>{r, c}; 1420 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos); 1421 } 1422 } 1423 if (auto acc = op.acc()) 1424 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1425 rewriter.replaceOp(op, res); 1426 return success(); 1427 } 1428 1429 /// Progressive lowering of ContractionOp. 1430 /// One: 1431 /// %x = vector.contract with at least one free/batch dimension 1432 /// is replaced by: 1433 /// %a = vector.contract with one less free/batch dimension 1434 /// %b = vector.contract with one less free/batch dimension 1435 /// .. 1436 /// %x = combine %a %b .. 1437 /// until a pure contraction is reached (no free/batch dimensions), 1438 /// which is replaced by a dot-product. 1439 /// 1440 /// This only kicks in when either VectorTransformsOptions is set 1441 /// to DOT or when other contraction patterns fail. 
1442 // 1443 // TODO: break down into transpose/reshape/cast ops 1444 // when they become available to avoid code dup 1445 // TODO: investigate lowering order impact on performance 1446 LogicalResult 1447 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1448 PatternRewriter &rewriter) const { 1449 // TODO: implement masks. 1450 if (llvm::size(op.masks()) != 0) 1451 return failure(); 1452 1453 if (failed(filter(op))) 1454 return failure(); 1455 1456 // TODO: support mixed mode contract lowering. 1457 if (op.getLhsType().getElementType() != 1458 getElementTypeOrSelf(op.getAccType()) || 1459 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1460 return failure(); 1461 1462 // TODO: implement benefits, cost models. 1463 MLIRContext *ctx = op.getContext(); 1464 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1465 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1466 return success(); 1467 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1468 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1469 return success(); 1470 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1471 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1472 return success(); 1473 1474 // Find first batch dimension in LHS/RHS, and lower when found. 1475 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1476 if (!batchDimMap.empty()) { 1477 int64_t lhsIndex = batchDimMap[0].first; 1478 int64_t rhsIndex = batchDimMap[0].second; 1479 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1480 return success(); 1481 } 1482 1483 // Collect contracting dimensions. 1484 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1485 op.getContractingDimMap(); 1486 DenseSet<int64_t> lhsContractingDimSet; 1487 DenseSet<int64_t> rhsContractingDimSet; 1488 for (auto &dimPair : contractingDimMap) { 1489 lhsContractingDimSet.insert(dimPair.first); 1490 rhsContractingDimSet.insert(dimPair.second); 1491 } 1492 1493 // Find first free dimension in LHS, and lower when found. 1494 VectorType lhsType = op.getLhsType(); 1495 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1496 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1497 rewriter.replaceOp( 1498 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1499 return success(); 1500 } 1501 } 1502 1503 // Find first free dimension in RHS, and lower when found. 1504 VectorType rhsType = op.getRhsType(); 1505 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1506 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1507 rewriter.replaceOp( 1508 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1509 return success(); 1510 } 1511 } 1512 1513 // Lower the first remaining reduction dimension. 1514 if (!contractingDimMap.empty()) { 1515 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1516 return success(); 1517 } 1518 1519 return failure(); 1520 } 1521 1522 // Lower one parallel dimension. 1523 // TODO: consider reusing existing contract unrolling 1524 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1525 int64_t lhsIndex, int64_t rhsIndex, 1526 PatternRewriter &rewriter) const { 1527 VectorType lhsType = op.getLhsType(); 1528 VectorType rhsType = op.getRhsType(); 1529 VectorType resType = op.getResultType().cast<VectorType>(); 1530 // Find the iterator type index and result index. 
1531 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1532 int64_t iterIndex = -1; 1533 int64_t dimSize = -1; 1534 if (lhsIndex >= 0) { 1535 iterIndex = iMap[0].getDimPosition(lhsIndex); 1536 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1537 "parallel index should be free in LHS or batch in LHS/RHS"); 1538 dimSize = lhsType.getDimSize(lhsIndex); 1539 } else { 1540 assert(rhsIndex >= 0 && "missing parallel index"); 1541 iterIndex = iMap[1].getDimPosition(rhsIndex); 1542 dimSize = rhsType.getDimSize(rhsIndex); 1543 } 1544 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1545 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1546 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1547 int64_t resIndex = lookup.getValue(); 1548 // Construct new iterator types and affine map array attribute. 1549 std::array<AffineMap, 3> lowIndexingMaps = { 1550 adjustMap(iMap[0], iterIndex, rewriter), 1551 adjustMap(iMap[1], iterIndex, rewriter), 1552 adjustMap(iMap[2], iterIndex, rewriter)}; 1553 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1554 auto lowIter = 1555 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1556 // Unroll into a series of lower dimensional vector.contract ops. 1557 Location loc = op.getLoc(); 1558 Value result = rewriter.create<arith::ConstantOp>( 1559 loc, resType, rewriter.getZeroAttr(resType)); 1560 for (int64_t d = 0; d < dimSize; ++d) { 1561 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1562 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1563 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1564 Value lowContract = rewriter.create<vector::ContractionOp>( 1565 loc, lhs, rhs, acc, lowAffine, lowIter); 1566 result = 1567 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1568 } 1569 return result; 1570 } 1571 1572 // Lower one reduction dimension. 1573 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1574 PatternRewriter &rewriter) const { 1575 auto loc = op.getLoc(); 1576 VectorType lhsType = op.getLhsType(); 1577 VectorType rhsType = op.getRhsType(); 1578 Type resType = op.getResultType(); 1579 assert(!resType.isa<VectorType>()); 1580 bool isInt = resType.isa<IntegerType>(); 1581 // Use iterator index 0. 1582 int64_t iterIndex = 0; 1583 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1584 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1585 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1586 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1587 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1588 int64_t lhsIndex = lookupLhs.getValue(); 1589 int64_t rhsIndex = lookupRhs.getValue(); 1590 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1591 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1592 // Base case. 1593 if (lhsType.getRank() == 1) { 1594 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1595 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1596 auto kind = vector::CombiningKind::ADD; 1597 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1598 if (auto acc = op.acc()) 1599 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1600 return res; 1601 } 1602 // Construct new iterator types and affine map array attribute. 
1603 std::array<AffineMap, 3> lowIndexingMaps = { 1604 adjustMap(iMap[0], iterIndex, rewriter), 1605 adjustMap(iMap[1], iterIndex, rewriter), 1606 adjustMap(iMap[2], iterIndex, rewriter)}; 1607 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1608 auto lowIter = 1609 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1610 // Unroll into a series of lower dimensional vector.contract ops. 1611 // By feeding the initial accumulator into the first contraction, 1612 // and the result of each contraction into the next, eventually 1613 // the sum of all reductions is computed. 1614 Value result = op.acc(); 1615 for (int64_t d = 0; d < dimSize; ++d) { 1616 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1617 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1618 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1619 lowAffine, lowIter); 1620 } 1621 return result; 1622 } 1623 1624 } // namespace mlir 1625 1626 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1627 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1628 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1629 OpBuilder::InsertionGuard guard(builder); 1630 builder.setInsertionPointAfter(op); 1631 Location loc = op->getLoc(); 1632 if (op->getNumResults() != 1) 1633 return {}; 1634 Value result = op->getResult(0); 1635 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1636 if (!type || map.getNumResults() != multiplicity.size()) 1637 return {}; 1638 // For each dimension being distributed check that the size is a multiple of 1639 // the multiplicity. To handle more sizes we would need to support masking. 1640 unsigned multiplictyCount = 0; 1641 for (auto exp : map.getResults()) { 1642 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1643 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1644 type.getDimSize(affinExp.getPosition()) % 1645 multiplicity[multiplictyCount++] != 1646 0) 1647 return {}; 1648 } 1649 DistributeOps ops; 1650 ops.extract = 1651 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1652 ops.insert = 1653 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1654 return ops; 1655 } 1656 1657 /// Progressive lowering of transfer_read. This pattern supports lowering of 1658 /// `vector.transfer_read` to a combination of `vector.load` and 1659 /// `vector.broadcast` if all of the following hold: 1660 /// - Stride of most minor memref dimension must be 1. 1661 /// - Out-of-bounds masking is not required. 1662 /// - If the memref's element type is a vector type then it coincides with the 1663 /// result type. 1664 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1665 struct TransferReadToVectorLoadLowering 1666 : public OpRewritePattern<vector::TransferReadOp> { 1667 TransferReadToVectorLoadLowering(MLIRContext *context, 1668 llvm::Optional<unsigned> maxRank) 1669 : OpRewritePattern<vector::TransferReadOp>(context), 1670 maxTransferRank(maxRank) {} 1671 1672 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1673 PatternRewriter &rewriter) const override { 1674 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1675 return failure(); 1676 1677 SmallVector<unsigned, 4> broadcastedDims; 1678 // Permutations are handled by VectorToSCF or 1679 // populateVectorTransferPermutationMapLoweringPatterns. 1680 // We let the 0-d corner case pass-through as it is supported. 
1681 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1682 &broadcastedDims)) 1683 return failure(); 1684 1685 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1686 if (!memRefType) 1687 return failure(); 1688 1689 // Non-unit strides are handled by VectorToSCF. 1690 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1691 return failure(); 1692 1693 // If there is broadcasting involved then we first load the unbroadcasted 1694 // vector, and then broadcast it with `vector.broadcast`. 1695 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1696 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1697 vectorShape.end()); 1698 for (unsigned i : broadcastedDims) 1699 unbroadcastedVectorShape[i] = 1; 1700 VectorType unbroadcastedVectorType = VectorType::get( 1701 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1702 1703 // `vector.load` supports vector types as memref's elements only when the 1704 // resulting vector type is the same as the element type. 1705 auto memrefElTy = memRefType.getElementType(); 1706 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1707 return failure(); 1708 1709 // Otherwise, element types of the memref and the vector must match. 1710 if (!memrefElTy.isa<VectorType>() && 1711 memrefElTy != read.getVectorType().getElementType()) 1712 return failure(); 1713 1714 // Out-of-bounds dims are handled by MaterializeTransferMask. 1715 if (read.hasOutOfBoundsDim()) 1716 return failure(); 1717 1718 // Create vector load op. 1719 Operation *loadOp; 1720 if (read.mask()) { 1721 Value fill = rewriter.create<vector::SplatOp>( 1722 read.getLoc(), unbroadcastedVectorType, read.padding()); 1723 loadOp = rewriter.create<vector::MaskedLoadOp>( 1724 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1725 read.mask(), fill); 1726 } else { 1727 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1728 unbroadcastedVectorType, 1729 read.source(), read.indices()); 1730 } 1731 1732 // Insert a broadcasting op if required. 1733 if (!broadcastedDims.empty()) { 1734 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1735 read, read.getVectorType(), loadOp->getResult(0)); 1736 } else { 1737 rewriter.replaceOp(read, loadOp->getResult(0)); 1738 } 1739 1740 return success(); 1741 } 1742 1743 llvm::Optional<unsigned> maxTransferRank; 1744 }; 1745 1746 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1747 // TODO: we shouldn't cross the vector/scalar domains just for this 1748 // but atm we lack the infra to avoid it. Possible solutions include: 1749 // - go directly to LLVM + bitcast 1750 // - introduce a bitcast op and likely a new pointer dialect 1751 // - let memref.load/store additionally support the 0-d vector case 1752 // There are still deeper data layout issues lingering even in this 1753 // trivial case (for architectures for which this matters). 
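//
// A minimal sketch of the rewrite, shown for a single-element load (the
// pattern matches any vector type with exactly one element; SSA names are
// illustrative only):
//
//   %0 = vector.load %base[%i] : memref<?xf32>, vector<1xf32>
//
// becomes
//
//   %s = memref.load %base[%i] : memref<?xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>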
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.base(), loadOp.indices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.valueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.valueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.base(), storeOp.indices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferWriteOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return failure();

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.permutation_map().isMinorIdentity())
      return failure();

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.mask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.source(), write.indices(), write.mask(), write.vector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.vector(), write.source(), write.indices());
    }
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid an infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.position());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.source(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));

    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given that we are now extracting from fewer elements.
    ArrayAttr newOffsets = extractOp.offsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.sizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t, 4> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
        newSizes, extractOp.strides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require the insert op to have the same rank for the source and
    // destination vector; other cases to be implemented.
2060 if (rank != insertOp.getDestVectorType().getRank()) 2061 return failure(); 2062 2063 ArrayAttr newOffsets = insertOp.offsets(); 2064 assert(newOffsets.size() == rank); 2065 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2066 if (offsets.back() % shrinkRatio != 0) 2067 return failure(); 2068 offsets.back() = offsets.back() / shrinkRatio; 2069 newOffsets = rewriter.getI64ArrayAttr(offsets); 2070 2071 SmallVector<int64_t, 4> srcDims = 2072 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2073 srcDims.back() = srcDims.back() / shrinkRatio; 2074 VectorType newCastSrcType = 2075 VectorType::get(srcDims, castDstType.getElementType()); 2076 2077 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2078 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2079 2080 SmallVector<int64_t, 4> dstDims = 2081 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2082 dstDims.back() = dstDims.back() / shrinkRatio; 2083 VectorType newCastDstType = 2084 VectorType::get(dstDims, castDstType.getElementType()); 2085 2086 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2087 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2088 2089 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2090 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2091 insertOp.strides()); 2092 2093 return success(); 2094 } 2095 }; 2096 2097 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, 2098 Type targetType, Value value) { 2099 if (targetType == value.getType()) 2100 return value; 2101 2102 bool targetIsIndex = targetType.isIndex(); 2103 bool valueIsIndex = value.getType().isIndex(); 2104 if (targetIsIndex ^ valueIsIndex) 2105 return rewriter.create<arith::IndexCastOp>(loc, targetType, value); 2106 2107 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 2108 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 2109 assert(targetIntegerType && valueIntegerType && 2110 "unexpected cast between types other than integers and index"); 2111 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 2112 2113 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 2114 return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value); 2115 return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value); 2116 } 2117 2118 // Helper that returns a vector comparison that constructs a mask: 2119 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2120 // 2121 // If `dim == 0` then the result will be a 0-D vector. 2122 // 2123 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2124 // much more compact, IR for this operation, but LLVM eventually 2125 // generates more elaborate instructions for this intrinsic since it 2126 // is very conservative on the boundary conditions. 2127 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2128 bool indexOptimizations, int64_t dim, 2129 Value b, Value *off = nullptr) { 2130 auto loc = op->getLoc(); 2131 // If we can assume all indices fit in 32-bit, we perform the vector 2132 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2133 // Otherwise we perform the vector comparison using 64-bit indices. 2134 Type idxType = 2135 indexOptimizations ? 
rewriter.getI32Type() : rewriter.getI64Type(); 2136 DenseIntElementsAttr indicesAttr; 2137 if (dim == 0 && indexOptimizations) { 2138 indicesAttr = DenseIntElementsAttr::get( 2139 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2140 } else if (dim == 0) { 2141 indicesAttr = DenseIntElementsAttr::get( 2142 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2143 } else if (indexOptimizations) { 2144 indicesAttr = rewriter.getI32VectorAttr( 2145 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2146 } else { 2147 indicesAttr = rewriter.getI64VectorAttr( 2148 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2149 } 2150 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2151 // Add in an offset if requested. 2152 if (off) { 2153 Value o = createCastToIndexLike(rewriter, loc, idxType, *off); 2154 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2155 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2156 } 2157 // Construct the vector comparison. 2158 Value bound = createCastToIndexLike(rewriter, loc, idxType, b); 2159 Value bounds = 2160 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2161 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2162 bounds); 2163 } 2164 2165 template <typename ConcreteOp> 2166 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2167 public: 2168 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2169 : mlir::OpRewritePattern<ConcreteOp>(context), 2170 indexOptimizations(enableIndexOpt) {} 2171 2172 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2173 PatternRewriter &rewriter) const override { 2174 if (!xferOp.hasOutOfBoundsDim()) 2175 return failure(); 2176 2177 if (xferOp.getVectorType().getRank() > 1 || 2178 llvm::size(xferOp.indices()) == 0) 2179 return failure(); 2180 2181 Location loc = xferOp->getLoc(); 2182 VectorType vtp = xferOp.getVectorType(); 2183 2184 // * Create a vector with linear indices [ 0 .. vector_length - 1 ]. 2185 // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 2186 // * Let dim the memref dimension, compute the vector comparison mask 2187 // (in-bounds mask): 2188 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 2189 // 2190 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2191 // dimensions here. 2192 unsigned vecWidth = vtp.getNumElements(); 2193 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2194 Value off = xferOp.indices()[lastIndex]; 2195 Value dim = 2196 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2197 Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations, 2198 vecWidth, dim, &off); 2199 2200 if (xferOp.mask()) { 2201 // Intersect the in-bounds with the mask specified as an op parameter. 2202 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2203 } 2204 2205 rewriter.updateRootInPlace(xferOp, [&]() { 2206 xferOp.maskMutable().assign(mask); 2207 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2208 }); 2209 2210 return success(); 2211 } 2212 2213 private: 2214 const bool indexOptimizations; 2215 }; 2216 2217 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 
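///
/// A sketch of the generated IR, assuming `indexOptimizations` is enabled so
/// the comparison is carried out in i32 (otherwise i64 is used); SSA names
/// are illustrative only:
///
///   %m = vector.create_mask %ub : vector<4xi1>
///
/// becomes, roughly,
///
///   %steps  = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %ub_i32 = arith.index_cast %ub : index to i32
///   %bounds = vector.splat %ub_i32 : vector<4xi32>
///   %m      = arith.cmpi slt, %steps, %bounds : vector<4xi32>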
2218 class VectorCreateMaskOpConversion 2219 : public OpRewritePattern<vector::CreateMaskOp> { 2220 public: 2221 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2222 bool enableIndexOpt) 2223 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2224 indexOptimizations(enableIndexOpt) {} 2225 2226 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2227 PatternRewriter &rewriter) const override { 2228 auto dstType = op.getType(); 2229 int64_t rank = dstType.getRank(); 2230 if (rank > 1) 2231 return failure(); 2232 rewriter.replaceOp( 2233 op, buildVectorComparison(rewriter, op, indexOptimizations, 2234 rank == 0 ? 0 : dstType.getDimSize(0), 2235 op.getOperand(0))); 2236 return success(); 2237 } 2238 2239 private: 2240 const bool indexOptimizations; 2241 }; 2242 2243 // Drop inner most contiguous unit dimensions from transfer_read operand. 2244 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2245 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2246 2247 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2248 PatternRewriter &rewriter) const override { 2249 // TODO: support 0-d corner case. 2250 if (readOp.getTransferRank() == 0) 2251 return failure(); 2252 2253 // TODO: support mask. 2254 if (readOp.mask()) 2255 return failure(); 2256 2257 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2258 if (!srcType || !srcType.hasStaticShape()) 2259 return failure(); 2260 2261 if (!readOp.permutation_map().isMinorIdentity()) 2262 return failure(); 2263 2264 auto targetType = readOp.getVectorType(); 2265 if (targetType.getRank() <= 1) 2266 return failure(); 2267 2268 SmallVector<int64_t> srcStrides; 2269 int64_t srcOffset; 2270 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2271 return failure(); 2272 2273 size_t dimsToDrop = 0; 2274 for (size_t i = 1; i < srcStrides.size(); ++i) { 2275 int dim = srcType.getRank() - i - 1; 2276 if (srcStrides[dim] == 1) { 2277 dimsToDrop++; 2278 } else { 2279 break; 2280 } 2281 } 2282 if (dimsToDrop == 0) 2283 return failure(); 2284 2285 auto resultTargetVecType = 2286 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2287 targetType.getElementType()); 2288 2289 MemRefType resultMemrefType; 2290 if (srcType.getLayout().getAffineMap().isIdentity()) { 2291 resultMemrefType = MemRefType::get( 2292 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2293 {}, srcType.getMemorySpaceAsInt()); 2294 } else { 2295 AffineMap map = srcType.getLayout().getAffineMap(); 2296 int numResultDims = map.getNumDims() - dimsToDrop; 2297 int numSymbols = map.getNumSymbols(); 2298 for (size_t i = 0; i < dimsToDrop; ++i) { 2299 int dim = srcType.getRank() - i - 1; 2300 map = map.replace(rewriter.getAffineDimExpr(dim), 2301 rewriter.getAffineConstantExpr(0), numResultDims, 2302 numSymbols); 2303 } 2304 resultMemrefType = MemRefType::get( 2305 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2306 map, srcType.getMemorySpaceAsInt()); 2307 } 2308 2309 auto loc = readOp.getLoc(); 2310 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2311 SmallVector<int64_t> strides(srcType.getRank(), 1); 2312 2313 ArrayAttr inBoundsAttr = 2314 readOp.in_bounds() 2315 ? 
rewriter.getArrayAttr( 2316 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2317 : ArrayAttr(); 2318 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2319 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2320 strides); 2321 auto permMap = getTransferMinorIdentityMap( 2322 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2323 Value result = rewriter.create<vector::TransferReadOp>( 2324 loc, resultTargetVecType, rankedReducedView, 2325 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2326 readOp.padding(), 2327 // TODO: support mask. 2328 /*mask=*/Value(), inBoundsAttr); 2329 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2330 result); 2331 return success(); 2332 } 2333 }; 2334 2335 namespace { 2336 2337 /// This function checks to see if the vector combining kind 2338 /// is consistent with the integer or float element type. 2339 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2340 using vector::CombiningKind; 2341 enum class KindType { FLOAT, INT, INVALID }; 2342 KindType type{KindType::INVALID}; 2343 switch (kind) { 2344 case CombiningKind::MINF: 2345 case CombiningKind::MAXF: 2346 type = KindType::FLOAT; 2347 break; 2348 case CombiningKind::MINUI: 2349 case CombiningKind::MINSI: 2350 case CombiningKind::MAXUI: 2351 case CombiningKind::MAXSI: 2352 case CombiningKind::AND: 2353 case CombiningKind::OR: 2354 case CombiningKind::XOR: 2355 type = KindType::INT; 2356 break; 2357 case CombiningKind::ADD: 2358 case CombiningKind::MUL: 2359 type = isInt ? KindType::INT : KindType::FLOAT; 2360 break; 2361 } 2362 bool isValidIntKind = (type == KindType::INT) && isInt; 2363 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2364 return (isValidIntKind || isValidFloatKind); 2365 } 2366 2367 /// This function constructs the appropriate integer or float 2368 /// operation given the vector combining kind and operands. The 2369 /// supported int operations are : add, mul, min (signed/unsigned), 2370 /// max(signed/unsigned), and, or, xor. The supported float 2371 /// operations are : add, mul, min and max. 
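///
/// For instance (a sketch), combining two vector<2x1xi32> values %x and %y
/// with the ADD kind materializes
///   %r = arith.addi %x, %y : vector<2x1xi32>
/// while the MAXF kind on f32 vectors materializes arith.maxf.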
static Value genOperator(Location loc, Value x, Value y,
                         vector::CombiningKind kind,
                         PatternRewriter &rewriter) {
  using vector::CombiningKind;

  auto elType = x.getType().cast<VectorType>().getElementType();
  bool isInt = elType.isIntOrIndex();

  Value combinedResult{nullptr};
  switch (kind) {
  case CombiningKind::ADD:
    if (isInt)
      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
    break;
  case CombiningKind::MUL:
    if (isInt)
      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
    break;
  case CombiningKind::MINUI:
    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
    break;
  case CombiningKind::MINSI:
    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
    break;
  case CombiningKind::MAXUI:
    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
    break;
  case CombiningKind::MAXSI:
    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
    break;
  case CombiningKind::AND:
    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
    break;
  case CombiningKind::OR:
    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
    break;
  case CombiningKind::XOR:
    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
    break;
  case CombiningKind::MINF:
    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
    break;
  case CombiningKind::MAXF:
    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
    break;
  }
  return combinedResult;
}

/// Convert vector.scan op into arith ops and
/// vector.insert_strided_slice/extract_strided_slice.
///
/// Ex:
/// ```
///   %0:2 = vector.scan <add>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.addi %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.addi %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType
destType = scanOp.getDestType(); 2459 ArrayRef<int64_t> destShape = destType.getShape(); 2460 auto elType = destType.getElementType(); 2461 bool isInt = elType.isIntOrIndex(); 2462 if (!isValidKind(isInt, scanOp.kind())) 2463 return failure(); 2464 2465 VectorType resType = VectorType::get(destShape, elType); 2466 Value result = rewriter.create<arith::ConstantOp>( 2467 loc, resType, rewriter.getZeroAttr(resType)); 2468 int64_t reductionDim = scanOp.reduction_dim(); 2469 bool inclusive = scanOp.inclusive(); 2470 int64_t destRank = destType.getRank(); 2471 VectorType initialValueType = scanOp.getInitialValueType(); 2472 int64_t initialValueRank = initialValueType.getRank(); 2473 2474 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2475 reductionShape[reductionDim] = 1; 2476 VectorType reductionType = VectorType::get(reductionShape, elType); 2477 SmallVector<int64_t> offsets(destRank, 0); 2478 SmallVector<int64_t> strides(destRank, 1); 2479 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2480 sizes[reductionDim] = 1; 2481 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2482 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2483 2484 Value lastOutput, lastInput; 2485 for (int i = 0; i < destShape[reductionDim]; i++) { 2486 offsets[reductionDim] = i; 2487 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2488 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2489 loc, reductionType, scanOp.source(), scanOffsets, scanSizes, 2490 scanStrides); 2491 Value output; 2492 if (i == 0) { 2493 if (inclusive) { 2494 output = input; 2495 } else { 2496 if (initialValueRank == 0) { 2497 // ShapeCastOp cannot handle 0-D vectors 2498 output = rewriter.create<vector::BroadcastOp>( 2499 loc, input.getType(), scanOp.initial_value()); 2500 } else { 2501 output = rewriter.create<vector::ShapeCastOp>( 2502 loc, input.getType(), scanOp.initial_value()); 2503 } 2504 } 2505 } else { 2506 Value y = inclusive ? 
input : lastInput; 2507 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2508 assert(output != nullptr); 2509 } 2510 result = rewriter.create<vector::InsertStridedSliceOp>( 2511 loc, output, result, offsets, strides); 2512 lastOutput = output; 2513 lastInput = input; 2514 } 2515 2516 Value reduction; 2517 if (initialValueRank == 0) { 2518 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2519 reduction = 2520 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2521 } else { 2522 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2523 lastOutput); 2524 } 2525 2526 rewriter.replaceOp(scanOp, {result, reduction}); 2527 return success(); 2528 } 2529 }; 2530 2531 } // namespace 2532 2533 void mlir::vector::populateVectorMaskMaterializationPatterns( 2534 RewritePatternSet &patterns, bool indexOptimizations) { 2535 patterns.add<VectorCreateMaskOpConversion, 2536 MaterializeTransferMask<vector::TransferReadOp>, 2537 MaterializeTransferMask<vector::TransferWriteOp>>( 2538 patterns.getContext(), indexOptimizations); 2539 } 2540 2541 void mlir::vector::populateShapeCastFoldingPatterns( 2542 RewritePatternSet &patterns) { 2543 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2544 } 2545 2546 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2547 RewritePatternSet &patterns) { 2548 patterns.add<BubbleDownVectorBitCastForExtract, 2549 BubbleDownBitCastForStridedSliceExtract, 2550 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2551 } 2552 2553 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2554 RewritePatternSet &patterns) { 2555 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2556 } 2557 2558 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2559 RewritePatternSet &patterns) { 2560 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2561 patterns.getContext()); 2562 } 2563 2564 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2565 RewritePatternSet &patterns) { 2566 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2567 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2568 patterns.getContext()); 2569 } 2570 2571 void mlir::vector::populateVectorContractLoweringPatterns( 2572 RewritePatternSet &patterns, VectorTransformsOptions options) { 2573 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2574 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2575 ContractionOpToOuterProductOpLowering>(options, 2576 patterns.getContext()); 2577 } 2578 2579 void mlir::vector::populateVectorTransposeLoweringPatterns( 2580 RewritePatternSet &patterns, VectorTransformsOptions options) { 2581 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2582 options, patterns.getContext()); 2583 } 2584 2585 void mlir::vector::populateVectorReductionToContractPatterns( 2586 RewritePatternSet &patterns) { 2587 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2588 CombineContractTranspose>(patterns.getContext()); 2589 } 2590 2591 void mlir::vector:: 2592 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2593 RewritePatternSet &patterns) { 2594 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2595 } 2596 2597 void mlir::vector::populateVectorTransferLoweringPatterns( 2598 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2599 patterns.add<TransferReadToVectorLoadLowering, 2600 TransferWriteToVectorStoreLowering>(patterns.getContext(), 2601 maxTransferRank); 2602 patterns 
2603 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2604 patterns.getContext()); 2605 } 2606 2607 void mlir::vector::populateVectorScanLoweringPatterns( 2608 RewritePatternSet &patterns) { 2609 patterns.add<ScanToArithOps>(patterns.getContext()); 2610 } 2611
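// Example usage (a sketch; `ctx`, `options`, and `op` stand for a suitable
// MLIRContext, VectorTransformsOptions, and root operation and are not defined
// in this file): the populate* entry points above are typically combined into
// a single RewritePatternSet and applied through the greedy pattern driver:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorScanLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));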