//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
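  // For example, dropping dimension `index` of a rank-3 value extracts each
  // rank-2 slice along the leading dimension, recursively reshapes that slice
  // with `index - 1`, and re-inserts the slices into a result whose shape has
  // dimension `index` removed.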
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
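    // Illustrative flow for the Flat option on a vector<2x3xf32> transpose:
    // shape_cast to vector<6xf32>, vector.flat_transpose with rows = 3 and
    // columns = 2 (taken from the vector<3x2xf32> result shape), and a final
    // shape_cast back to vector<3x2xf32>.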
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate fully unrolled extract/insert ops.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    SmallVector<int64_t, 4> lhs(transp.size(), 0);
    SmallVector<int64_t, 4> rhs(transp.size(), 0);
    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
                                         op.vector(), result, rewriter));
    return success();
  }

private:
  // Builds the indices arrays for the lhs and rhs. Generates the
  // extract/insert operation when all ranks are exhausted.
  Value expandIndices(Location loc, VectorType resType, int64_t pos,
                      SmallVector<int64_t, 4> &transp,
                      SmallVector<int64_t, 4> &lhs,
                      SmallVector<int64_t, 4> &rhs, Value input, Value result,
                      PatternRewriter &rewriter) const {
    if (pos >= resType.getRank()) {
      auto ridx = rewriter.getI64ArrayAttr(rhs);
      auto lidx = rewriter.getI64ArrayAttr(lhs);
      Type eltType = resType.getElementType();
      Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
      return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
    }
    for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
      lhs[pos] = d;
      rhs[transp[pos]] = d;
      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
                             result, rewriter);
    }
    return result;
  }

  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    Value combinedResult;
    switch (kind) {
    case CombiningKind::ADD:
      combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
      break;
    case CombiningKind::MUL:
      combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINUI:
      combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINSI:
      combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXUI:
      combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXSI:
      combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
      break;
    case CombiningKind::AND:
      combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
      break;
    case CombiningKind::OR:
      combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
      break;
    case CombiningKind::XOR:
      combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
      break;
    case CombiningKind::MINF: // Only valid for floating point types.
    case CombiningKind::MAXF: // Only valid for floating point types.
      return Optional<Value>();
    }
    return Optional<Value>(combinedResult);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    Value combinedResult;
    switch (kind) {
    case CombiningKind::MUL:
      combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
      break;
    case CombiningKind::MINF:
      combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
      break;
    case CombiningKind::MAXF:
      combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
      break;
    case CombiningKind::ADD:   // Already handled this special case above.
    case CombiningKind::AND:   // Only valid for integer types.
    case CombiningKind::MINUI: // Only valid for integer types.
    case CombiningKind::MINSI: // Only valid for integer types.
    case CombiningKind::MAXUI: // Only valid for integer types.
    case CombiningKind::MAXSI: // Only valid for integer types.
    case CombiningKind::OR:    // Only valid for integer types.
    case CombiningKind::XOR:   // Only valid for integer types.
      return Optional<Value>();
    }
    return Optional<Value>(combinedResult);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
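///
/// For example (illustrative), one lowering step of
///   %x = vector.create_mask %a, %b : vector<4x3xi1>
/// creates a single lower-rank mask %l = vector.create_mask %b : vector<3xi1>
/// and, for each i in [0, 4), selects between %l and an all-false vector based
/// on arith.cmpi slt(%ci, %a), inserting the selection at position [i].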
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
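///
/// For example (illustrative), vector<6xf32> -> vector<2x3xf32> becomes two
/// vector.extract_strided_slice ops (offsets 0 and 3, sizes 3, strides 1),
/// each inserted into the 2-D result with vector.insert at positions [0]
/// and [1].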
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d2)>],
///    iterator_types = ["parallel", "reduction", "parallel"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true otherwise creates an
/// arith::AddFOp using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true otherwise creates an
/// arith::MulFOp using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0].cast<AffineMapAttr>().getValue();
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1].cast<AffineMapAttr>().getValue();
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
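  // For example, a vector<4x8xf32> lhs is flattened to vector<32xf32> below
  // and handed to vector.matmul together with the row/column sizes (4 and 8
  // here) taken from the 2-D type.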
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2].cast<AffineMapAttr>().getValue();
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"),
          m, ValueRange{});

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
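///
/// The implementation below first tries the Matmul, OuterProduct, and Dot
/// lowerings in that order, and only then peels off batch, free, and reduction
/// dimensions one at a time.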
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    rewriter.replaceOp(op, lowerReduction(op, rewriter));
    return success();
  }

  return failure();
}

// Lower one parallel dimension.
// TODO: consider reusing existing contract unrolling
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
                                           int64_t lhsIndex, int64_t rhsIndex,
                                           PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator type index and result index.
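  // The parallel dimension is identified through the LHS indexing map when a
  // free or batch LHS dimension was passed in, and through the RHS indexing
  // map otherwise.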
1546 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1547 int64_t iterIndex = -1; 1548 int64_t dimSize = -1; 1549 if (lhsIndex >= 0) { 1550 iterIndex = iMap[0].getDimPosition(lhsIndex); 1551 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1552 "parallel index should be free in LHS or batch in LHS/RHS"); 1553 dimSize = lhsType.getDimSize(lhsIndex); 1554 } else { 1555 assert(rhsIndex >= 0 && "missing parallel index"); 1556 iterIndex = iMap[1].getDimPosition(rhsIndex); 1557 dimSize = rhsType.getDimSize(rhsIndex); 1558 } 1559 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1560 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1561 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1562 int64_t resIndex = lookup.getValue(); 1563 // Construct new iterator types and affine map array attribute. 1564 std::array<AffineMap, 3> lowIndexingMaps = { 1565 adjustMap(iMap[0], iterIndex, rewriter), 1566 adjustMap(iMap[1], iterIndex, rewriter), 1567 adjustMap(iMap[2], iterIndex, rewriter)}; 1568 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1569 auto lowIter = 1570 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1571 // Unroll into a series of lower dimensional vector.contract ops. 1572 Location loc = op.getLoc(); 1573 Value result = rewriter.create<arith::ConstantOp>( 1574 loc, resType, rewriter.getZeroAttr(resType)); 1575 for (int64_t d = 0; d < dimSize; ++d) { 1576 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1577 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1578 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1579 Value lowContract = rewriter.create<vector::ContractionOp>( 1580 loc, lhs, rhs, acc, lowAffine, lowIter); 1581 result = 1582 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1583 } 1584 return result; 1585 } 1586 1587 // Lower one reduction dimension. 1588 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1589 PatternRewriter &rewriter) const { 1590 auto loc = op.getLoc(); 1591 VectorType lhsType = op.getLhsType(); 1592 VectorType rhsType = op.getRhsType(); 1593 Type resType = op.getResultType(); 1594 assert(!resType.isa<VectorType>()); 1595 bool isInt = resType.isa<IntegerType>(); 1596 // Use iterator index 0. 1597 int64_t iterIndex = 0; 1598 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1599 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1600 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1601 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1602 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1603 int64_t lhsIndex = lookupLhs.getValue(); 1604 int64_t rhsIndex = lookupRhs.getValue(); 1605 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1606 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1607 // Base case. 1608 if (lhsType.getRank() == 1) { 1609 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1610 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1611 StringAttr kind = rewriter.getStringAttr("add"); 1612 Value res = rewriter.create<vector::ReductionOp>(loc, resType, kind, m, 1613 ValueRange{}); 1614 if (auto acc = op.acc()) 1615 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1616 return res; 1617 } 1618 // Construct new iterator types and affine map array attribute. 
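  // The reduction dimension being unrolled (iteration dimension 0 here) is
  // removed from all three indexing maps and from the iterator list, with the
  // remaining dimensions renumbered, so each sub-contract below iterates over
  // one fewer dimension.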
1619 std::array<AffineMap, 3> lowIndexingMaps = { 1620 adjustMap(iMap[0], iterIndex, rewriter), 1621 adjustMap(iMap[1], iterIndex, rewriter), 1622 adjustMap(iMap[2], iterIndex, rewriter)}; 1623 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1624 auto lowIter = 1625 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1626 // Unroll into a series of lower dimensional vector.contract ops. 1627 // By feeding the initial accumulator into the first contraction, 1628 // and the result of each contraction into the next, eventually 1629 // the sum of all reductions is computed. 1630 Value result = op.acc(); 1631 for (int64_t d = 0; d < dimSize; ++d) { 1632 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1633 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1634 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1635 lowAffine, lowIter); 1636 } 1637 return result; 1638 } 1639 1640 } // namespace mlir 1641 1642 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1643 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1644 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1645 OpBuilder::InsertionGuard guard(builder); 1646 builder.setInsertionPointAfter(op); 1647 Location loc = op->getLoc(); 1648 if (op->getNumResults() != 1) 1649 return {}; 1650 Value result = op->getResult(0); 1651 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1652 if (!type || map.getNumResults() != multiplicity.size()) 1653 return {}; 1654 // For each dimension being distributed check that the size is a multiple of 1655 // the multiplicity. To handle more sizes we would need to support masking. 1656 unsigned multiplictyCount = 0; 1657 for (auto exp : map.getResults()) { 1658 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1659 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1660 type.getDimSize(affinExp.getPosition()) % 1661 multiplicity[multiplictyCount++] != 1662 0) 1663 return {}; 1664 } 1665 DistributeOps ops; 1666 ops.extract = 1667 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1668 ops.insert = 1669 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1670 return ops; 1671 } 1672 1673 /// Progressive lowering of transfer_read. This pattern supports lowering of 1674 /// `vector.transfer_read` to a combination of `vector.load` and 1675 /// `vector.broadcast` if all of the following hold: 1676 /// - Stride of most minor memref dimension must be 1. 1677 /// - Out-of-bounds masking is not required. 1678 /// - If the memref's element type is a vector type then it coincides with the 1679 /// result type. 1680 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1681 struct TransferReadToVectorLoadLowering 1682 : public OpRewritePattern<vector::TransferReadOp> { 1683 TransferReadToVectorLoadLowering(MLIRContext *context, 1684 llvm::Optional<unsigned> maxRank) 1685 : OpRewritePattern<vector::TransferReadOp>(context), 1686 maxTransferRank(maxRank) {} 1687 1688 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1689 PatternRewriter &rewriter) const override { 1690 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1691 return failure(); 1692 1693 SmallVector<unsigned, 4> broadcastedDims; 1694 // Permutations are handled by VectorToSCF or 1695 // populateVectorTransferPermutationMapLoweringPatterns. 1696 // We let the 0-d corner case pass-through as it is supported. 
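    // Roughly speaking, a map such as (d0, d1, d2) -> (d1, d2) is accepted
    // as-is, and (d0, d1, d2) -> (0, d2) is accepted with result dimension 0
    // recorded in `broadcastedDims`; a genuine permutation such as
    // (d0, d1, d2) -> (d2, d1) makes the pattern bail out here.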
1697 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1698 &broadcastedDims)) 1699 return failure(); 1700 1701 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1702 if (!memRefType) 1703 return failure(); 1704 1705 // Non-unit strides are handled by VectorToSCF. 1706 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1707 return failure(); 1708 1709 // If there is broadcasting involved then we first load the unbroadcasted 1710 // vector, and then broadcast it with `vector.broadcast`. 1711 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1712 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1713 vectorShape.end()); 1714 for (unsigned i : broadcastedDims) 1715 unbroadcastedVectorShape[i] = 1; 1716 VectorType unbroadcastedVectorType = VectorType::get( 1717 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1718 1719 // `vector.load` supports vector types as memref's elements only when the 1720 // resulting vector type is the same as the element type. 1721 auto memrefElTy = memRefType.getElementType(); 1722 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1723 return failure(); 1724 1725 // Otherwise, element types of the memref and the vector must match. 1726 if (!memrefElTy.isa<VectorType>() && 1727 memrefElTy != read.getVectorType().getElementType()) 1728 return failure(); 1729 1730 // Out-of-bounds dims are handled by MaterializeTransferMask. 1731 if (read.hasOutOfBoundsDim()) 1732 return failure(); 1733 1734 // Create vector load op. 1735 Operation *loadOp; 1736 if (read.mask()) { 1737 Value fill = rewriter.create<vector::SplatOp>( 1738 read.getLoc(), unbroadcastedVectorType, read.padding()); 1739 loadOp = rewriter.create<vector::MaskedLoadOp>( 1740 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1741 read.mask(), fill); 1742 } else { 1743 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1744 unbroadcastedVectorType, 1745 read.source(), read.indices()); 1746 } 1747 1748 // Insert a broadcasting op if required. 1749 if (!broadcastedDims.empty()) { 1750 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1751 read, read.getVectorType(), loadOp->getResult(0)); 1752 } else { 1753 rewriter.replaceOp(read, loadOp->getResult(0)); 1754 } 1755 1756 return success(); 1757 } 1758 1759 llvm::Optional<unsigned> maxTransferRank; 1760 }; 1761 1762 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1763 // TODO: we shouldn't cross the vector/scalar domains just for this 1764 // but atm we lack the infra to avoid it. Possible solutions include: 1765 // - go directly to LLVM + bitcast 1766 // - introduce a bitcast op and likely a new pointer dialect 1767 // - let memref.load/store additionally support the 0-d vector case 1768 // There are still deeper data layout issues lingering even in this 1769 // trivial case (for architectures for which this matters). 
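// A minimal sketch of the rewrite (SSA names and shapes illustrative):
//   %0 = vector.load %m[%i] : memref<?xf32>, vector<f32>
// becomes
//   %s = memref.load %m[%i] : memref<?xf32>
//   %0 = vector.broadcast %s : f32 to vector<f32>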
1770 struct VectorLoadToMemrefLoadLowering 1771 : public OpRewritePattern<vector::LoadOp> { 1772 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 1773 1774 LogicalResult matchAndRewrite(vector::LoadOp loadOp, 1775 PatternRewriter &rewriter) const override { 1776 auto vecType = loadOp.getVectorType(); 1777 if (vecType.getNumElements() != 1) 1778 return failure(); 1779 auto memrefLoad = rewriter.create<memref::LoadOp>( 1780 loadOp.getLoc(), loadOp.base(), loadOp.indices()); 1781 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, 1782 memrefLoad); 1783 return success(); 1784 } 1785 }; 1786 1787 /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 1788 struct VectorStoreToMemrefStoreLowering 1789 : public OpRewritePattern<vector::StoreOp> { 1790 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 1791 1792 LogicalResult matchAndRewrite(vector::StoreOp storeOp, 1793 PatternRewriter &rewriter) const override { 1794 auto vecType = storeOp.getVectorType(); 1795 if (vecType.getNumElements() != 1) 1796 return failure(); 1797 Value extracted; 1798 if (vecType.getRank() == 0) { 1799 // TODO: Unify once ExtractOp supports 0-d vectors. 1800 extracted = rewriter.create<vector::ExtractElementOp>( 1801 storeOp.getLoc(), storeOp.valueToStore()); 1802 } else { 1803 SmallVector<int64_t> indices(vecType.getRank(), 0); 1804 extracted = rewriter.create<vector::ExtractOp>( 1805 storeOp.getLoc(), storeOp.valueToStore(), indices); 1806 } 1807 1808 rewriter.replaceOpWithNewOp<memref::StoreOp>( 1809 storeOp, extracted, storeOp.base(), storeOp.indices()); 1810 return success(); 1811 } 1812 }; 1813 1814 /// Progressive lowering of transfer_write. This pattern supports lowering of 1815 /// `vector.transfer_write` to `vector.store` if all of the following hold: 1816 /// - Stride of most minor memref dimension must be 1. 1817 /// - Out-of-bounds masking is not required. 1818 /// - If the memref's element type is a vector type then it coincides with the 1819 /// type of the written value. 1820 /// - The permutation map is the minor identity map (neither permutation nor 1821 /// broadcasting is allowed). 1822 struct TransferWriteToVectorStoreLowering 1823 : public OpRewritePattern<vector::TransferWriteOp> { 1824 TransferWriteToVectorStoreLowering(MLIRContext *context, 1825 llvm::Optional<unsigned> maxRank) 1826 : OpRewritePattern<vector::TransferWriteOp>(context), 1827 maxTransferRank(maxRank) {} 1828 1829 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 1830 PatternRewriter &rewriter) const override { 1831 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) 1832 return failure(); 1833 1834 // Permutations are handled by VectorToSCF or 1835 // populateVectorTransferPermutationMapLoweringPatterns. 1836 if ( // pass-through for the 0-d corner case. 1837 !write.permutation_map().isMinorIdentity()) 1838 return failure(); 1839 1840 auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); 1841 if (!memRefType) 1842 return failure(); 1843 1844 // Non-unit strides are handled by VectorToSCF. 1845 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1846 return failure(); 1847 1848 // `vector.store` supports vector types as memref's elements only when the 1849 // type of the vector value being written is the same as the element type.
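    // E.g. (illustrative shapes) a transfer_write of a vector<4xf32> value
    // into a memref<8xvector<4xf32>> is accepted, while writing any other
    // vector type into that memref is rejected below.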
1850 auto memrefElTy = memRefType.getElementType(); 1851 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) 1852 return failure(); 1853 1854 // Otherwise, element types of the memref and the vector must match. 1855 if (!memrefElTy.isa<VectorType>() && 1856 memrefElTy != write.getVectorType().getElementType()) 1857 return failure(); 1858 1859 // Out-of-bounds dims are handled by MaterializeTransferMask. 1860 if (write.hasOutOfBoundsDim()) 1861 return failure(); 1862 if (write.mask()) { 1863 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( 1864 write, write.source(), write.indices(), write.mask(), write.vector()); 1865 } else { 1866 rewriter.replaceOpWithNewOp<vector::StoreOp>( 1867 write, write.vector(), write.source(), write.indices()); 1868 } 1869 return success(); 1870 } 1871 1872 llvm::Optional<unsigned> maxTransferRank; 1873 }; 1874 1875 // Returns the values in `arrayAttr` as an integer vector. 1876 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) { 1877 return llvm::to_vector<4>( 1878 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 1879 [](IntegerAttr attr) { return attr.getInt(); })); 1880 } 1881 1882 // Shuffles vector.bitcast op after vector.extract op. 1883 // 1884 // This transforms IR like: 1885 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> 1886 // %1 = vector.extract %0[3] : vector<8xf16> 1887 // Into: 1888 // %0 = vector.extract %src[1] : vector<4xf32> 1889 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> 1890 // %2 = vector.extract %1[1] : vector<2xf16> 1891 struct BubbleDownVectorBitCastForExtract 1892 : public OpRewritePattern<vector::ExtractOp> { 1893 using OpRewritePattern::OpRewritePattern; 1894 1895 LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 1896 PatternRewriter &rewriter) const override { 1897 // Only support extracting scalars for now. 1898 if (extractOp.getVectorType().getRank() != 1) 1899 return failure(); 1900 1901 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 1902 if (!castOp) 1903 return failure(); 1904 1905 VectorType castSrcType = castOp.getSourceVectorType(); 1906 VectorType castDstType = castOp.getResultVectorType(); 1907 assert(castSrcType.getRank() == castDstType.getRank()); 1908 1909 // Fail to match if we only have one element in the cast op source. 1910 // This is to avoid an infinite loop given that this pattern can generate 1911 // such cases. 1912 if (castSrcType.getNumElements() == 1) 1913 return failure(); 1914 1915 // Only support casting to a larger number of elements for now. 1916 // E.g., vector<4xf32> -> vector<8xf16>. 1917 if (castSrcType.getNumElements() > castDstType.getNumElements()) 1918 return failure(); 1919 1920 unsigned expandRatio = 1921 castDstType.getNumElements() / castSrcType.getNumElements(); 1922 1923 auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { 1924 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 1925 }; 1926 1927 uint64_t index = getFirstIntValue(extractOp.position()); 1928 1929 // Get the single scalar (as a vector) in the source value that packs the 1930 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> 1931 VectorType oneScalarType = 1932 VectorType::get({1}, castSrcType.getElementType()); 1933 Value packedValue = rewriter.create<vector::ExtractOp>( 1934 extractOp.getLoc(), oneScalarType, castOp.source(), 1935 rewriter.getI64ArrayAttr(index / expandRatio)); 1936 1937 // Cast it to a vector with the desired scalar's type. 1938 // E.g.
f32 -> vector<2xf16> 1939 VectorType packedType = 1940 VectorType::get({expandRatio}, castDstType.getElementType()); 1941 Value castedValue = rewriter.create<vector::BitCastOp>( 1942 extractOp.getLoc(), packedType, packedValue); 1943 1944 // Finally extract the desired scalar. 1945 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 1946 extractOp, extractOp.getType(), castedValue, 1947 rewriter.getI64ArrayAttr(index % expandRatio)); 1948 1949 return success(); 1950 } 1951 }; 1952 1953 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 1954 // 1955 // This transforms IR like: 1956 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 1957 // %0 = vector.extract_strided_slice %cast { 1958 // offsets = [4], sizes = [4], strides = [1] 1959 // } : vector<8xf16> to vector<4xf16> 1960 // Into: 1961 // %0 = vector.extract_strided_slice %src { 1962 // offsets = [2], sizes = [2], strides = [1] 1963 // } : vector<4xf32> to vector<2xf32> 1964 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 1965 struct BubbleDownBitCastForStridedSliceExtract 1966 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 1967 using OpRewritePattern::OpRewritePattern; 1968 1969 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 1970 PatternRewriter &rewriter) const override { 1971 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 1972 if (!castOp) 1973 return failure(); 1974 1975 VectorType castSrcType = castOp.getSourceVectorType(); 1976 VectorType castDstType = castOp.getResultVectorType(); 1977 assert(castSrcType.getRank() == castDstType.getRank()); 1978 1979 int64_t castSrcLastDim = castSrcType.getShape().back(); 1980 int64_t castDstLastDim = castDstType.getShape().back(); 1981 // Require casting to more elements for now; other cases to be implemented. 1982 if (castSrcLastDim > castDstLastDim) 1983 return failure(); 1984 1985 // Only accept all one strides for now. 1986 if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(), 1987 [](const APInt &val) { return !val.isOneValue(); })) 1988 return failure(); 1989 1990 unsigned rank = extractOp.getVectorType().getRank(); 1991 assert(castDstLastDim % castSrcLastDim == 0); 1992 int64_t expandRatio = castDstLastDim / castSrcLastDim; 1993 1994 // If we have a less number of offsets than the rank, then implicitly we 1995 // are selecting the full range for the last bitcasted dimension; other 1996 // dimensions aren't affected. Otherwise, we need to scale down the last 1997 // dimension's offset given we are extracting from less elements now. 1998 ArrayAttr newOffsets = extractOp.offsets(); 1999 if (newOffsets.size() == rank) { 2000 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2001 if (offsets.back() % expandRatio != 0) 2002 return failure(); 2003 offsets.back() = offsets.back() / expandRatio; 2004 newOffsets = rewriter.getI64ArrayAttr(offsets); 2005 } 2006 2007 // Similarly for sizes. 
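    // With the example from the pattern documentation above (f32 -> f16,
    // expandRatio = 2), offsets = [4] and sizes = [4] expressed in f16
    // elements rescale to offsets = [2] and sizes = [2] in f32 elements.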
2008 ArrayAttr newSizes = extractOp.sizes(); 2009 if (newSizes.size() == rank) { 2010 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2011 if (sizes.back() % expandRatio != 0) 2012 return failure(); 2013 sizes.back() = sizes.back() / expandRatio; 2014 newSizes = rewriter.getI64ArrayAttr(sizes); 2015 } 2016 2017 SmallVector<int64_t, 4> dims = 2018 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2019 dims.back() = dims.back() / expandRatio; 2020 VectorType newExtractType = 2021 VectorType::get(dims, castSrcType.getElementType()); 2022 2023 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2024 extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, 2025 newSizes, extractOp.strides()); 2026 2027 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2028 extractOp, extractOp.getType(), newExtractOp); 2029 2030 return success(); 2031 } 2032 }; 2033 2034 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2035 // 2036 // This transforms IR like: 2037 // %0 = vector.insert_strided_slice %src, %dst { 2038 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2039 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2040 // Into: 2041 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2042 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2043 // %2 = vector.insert_strided_slice %src, %dst { 2044 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2045 struct BubbleUpBitCastForStridedSliceInsert 2046 : public OpRewritePattern<vector::BitCastOp> { 2047 using OpRewritePattern::OpRewritePattern; 2048 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2049 PatternRewriter &rewriter) const override { 2050 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2051 VectorType castDstType = bitcastOp.getResultVectorType(); 2052 assert(castSrcType.getRank() == castDstType.getRank()); 2053 2054 int64_t castSrcLastDim = castSrcType.getShape().back(); 2055 int64_t castDstLastDim = castDstType.getShape().back(); 2056 // Require casting to less elements for now; other cases to be implemented. 2057 if (castSrcLastDim < castDstLastDim) 2058 return failure(); 2059 2060 assert(castSrcLastDim % castDstLastDim == 0); 2061 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2062 2063 auto insertOp = 2064 bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>(); 2065 if (!insertOp) 2066 return failure(); 2067 2068 // Only accept all one strides for now. 2069 if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(), 2070 [](const APInt &val) { return !val.isOneValue(); })) 2071 return failure(); 2072 2073 unsigned rank = insertOp.getSourceVectorType().getRank(); 2074 // Require insert op to have the same rank for the source and destination 2075 // vector; other cases to be implemented. 
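    // In the rewrite below (hypothetical shapes): for an f16 -> f32 bitcast,
    // shrinkRatio = 2, so insert offsets expressed in f16 elements, e.g.
    // offsets = [4], are rescaled to offsets = [2] in f32 elements; offsets
    // that are not a multiple of shrinkRatio cannot be expressed and abort
    // the match.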
2076 if (rank != insertOp.getDestVectorType().getRank()) 2077 return failure(); 2078 2079 ArrayAttr newOffsets = insertOp.offsets(); 2080 assert(newOffsets.size() == rank); 2081 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2082 if (offsets.back() % shrinkRatio != 0) 2083 return failure(); 2084 offsets.back() = offsets.back() / shrinkRatio; 2085 newOffsets = rewriter.getI64ArrayAttr(offsets); 2086 2087 SmallVector<int64_t, 4> srcDims = 2088 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2089 srcDims.back() = srcDims.back() / shrinkRatio; 2090 VectorType newCastSrcType = 2091 VectorType::get(srcDims, castDstType.getElementType()); 2092 2093 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2094 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2095 2096 SmallVector<int64_t, 4> dstDims = 2097 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2098 dstDims.back() = dstDims.back() / shrinkRatio; 2099 VectorType newCastDstType = 2100 VectorType::get(dstDims, castDstType.getElementType()); 2101 2102 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2103 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2104 2105 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2106 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2107 insertOp.strides()); 2108 2109 return success(); 2110 } 2111 }; 2112 2113 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, 2114 Type targetType, Value value) { 2115 if (targetType == value.getType()) 2116 return value; 2117 2118 bool targetIsIndex = targetType.isIndex(); 2119 bool valueIsIndex = value.getType().isIndex(); 2120 if (targetIsIndex ^ valueIsIndex) 2121 return rewriter.create<arith::IndexCastOp>(loc, targetType, value); 2122 2123 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 2124 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 2125 assert(targetIntegerType && valueIntegerType && 2126 "unexpected cast between types other than integers and index"); 2127 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 2128 2129 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 2130 return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value); 2131 return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value); 2132 } 2133 2134 // Helper that returns a vector comparison that constructs a mask: 2135 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2136 // 2137 // If `dim == 0` then the result will be a 0-D vector. 2138 // 2139 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2140 // much more compact, IR for this operation, but LLVM eventually 2141 // generates more elaborate instructions for this intrinsic since it 2142 // is very conservative on the boundary conditions. 2143 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2144 bool indexOptimizations, int64_t dim, 2145 Value b, Value *off = nullptr) { 2146 auto loc = op->getLoc(); 2147 // If we can assume all indices fit in 32-bit, we perform the vector 2148 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2149 // Otherwise we perform the vector comparison using 64-bit indices. 2150 Type idxType = 2151 indexOptimizations ? 
rewriter.getI32Type() : rewriter.getI64Type(); 2152 DenseIntElementsAttr indicesAttr; 2153 if (dim == 0 && indexOptimizations) { 2154 indicesAttr = DenseIntElementsAttr::get( 2155 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2156 } else if (dim == 0) { 2157 indicesAttr = DenseIntElementsAttr::get( 2158 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2159 } else if (indexOptimizations) { 2160 indicesAttr = rewriter.getI32VectorAttr( 2161 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2162 } else { 2163 indicesAttr = rewriter.getI64VectorAttr( 2164 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2165 } 2166 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2167 // Add in an offset if requested. 2168 if (off) { 2169 Value o = createCastToIndexLike(rewriter, loc, idxType, *off); 2170 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2171 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2172 } 2173 // Construct the vector comparison. 2174 Value bound = createCastToIndexLike(rewriter, loc, idxType, b); 2175 Value bounds = 2176 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2177 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2178 bounds); 2179 } 2180 2181 template <typename ConcreteOp> 2182 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2183 public: 2184 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2185 : mlir::OpRewritePattern<ConcreteOp>(context), 2186 indexOptimizations(enableIndexOpt) {} 2187 2188 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2189 PatternRewriter &rewriter) const override { 2190 if (!xferOp.hasOutOfBoundsDim()) 2191 return failure(); 2192 2193 if (xferOp.getVectorType().getRank() > 1 || 2194 llvm::size(xferOp.indices()) == 0) 2195 return failure(); 2196 2197 Location loc = xferOp->getLoc(); 2198 VectorType vtp = xferOp.getVectorType(); 2199 2200 // * Create a vector with linear indices [ 0 .. vector_length - 1 ]. 2201 // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 2202 // * Let dim the memref dimension, compute the vector comparison mask 2203 // (in-bounds mask): 2204 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 2205 // 2206 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2207 // dimensions here. 2208 unsigned vecWidth = vtp.getNumElements(); 2209 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2210 Value off = xferOp.indices()[lastIndex]; 2211 Value dim = 2212 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2213 Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations, 2214 vecWidth, dim, &off); 2215 2216 if (xferOp.mask()) { 2217 // Intersect the in-bounds with the mask specified as an op parameter. 2218 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2219 } 2220 2221 rewriter.updateRootInPlace(xferOp, [&]() { 2222 xferOp.maskMutable().assign(mask); 2223 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2224 }); 2225 2226 return success(); 2227 } 2228 2229 private: 2230 const bool indexOptimizations; 2231 }; 2232 2233 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 
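/// For instance (illustrative), `vector.create_mask %b : vector<4xi1>` is
/// rewritten via `buildVectorComparison` into an `arith.cmpi slt` of the
/// constant index vector [0, 1, 2, 3] against a splat of `%b`.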
2234 class VectorCreateMaskOpConversion 2235 : public OpRewritePattern<vector::CreateMaskOp> { 2236 public: 2237 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2238 bool enableIndexOpt) 2239 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2240 indexOptimizations(enableIndexOpt) {} 2241 2242 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2243 PatternRewriter &rewriter) const override { 2244 auto dstType = op.getType(); 2245 int64_t rank = dstType.getRank(); 2246 if (rank > 1) 2247 return failure(); 2248 rewriter.replaceOp( 2249 op, buildVectorComparison(rewriter, op, indexOptimizations, 2250 rank == 0 ? 0 : dstType.getDimSize(0), 2251 op.getOperand(0))); 2252 return success(); 2253 } 2254 2255 private: 2256 const bool indexOptimizations; 2257 }; 2258 2259 // Drop inner most contiguous unit dimensions from transfer_read operand. 2260 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2261 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2262 2263 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2264 PatternRewriter &rewriter) const override { 2265 // TODO: support 0-d corner case. 2266 if (readOp.getTransferRank() == 0) 2267 return failure(); 2268 2269 // TODO: support mask. 2270 if (readOp.mask()) 2271 return failure(); 2272 2273 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2274 if (!srcType || !srcType.hasStaticShape()) 2275 return failure(); 2276 2277 if (!readOp.permutation_map().isMinorIdentity()) 2278 return failure(); 2279 2280 auto targetType = readOp.getVectorType(); 2281 if (targetType.getRank() <= 1) 2282 return failure(); 2283 2284 SmallVector<int64_t> srcStrides; 2285 int64_t srcOffset; 2286 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2287 return failure(); 2288 2289 size_t dimsToDrop = 0; 2290 for (size_t i = 1; i < srcStrides.size(); ++i) { 2291 int dim = srcType.getRank() - i - 1; 2292 if (srcStrides[dim] == 1) { 2293 dimsToDrop++; 2294 } else { 2295 break; 2296 } 2297 } 2298 if (dimsToDrop == 0) 2299 return failure(); 2300 2301 auto resultTargetVecType = 2302 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2303 targetType.getElementType()); 2304 2305 MemRefType resultMemrefType; 2306 if (srcType.getLayout().getAffineMap().isIdentity()) { 2307 resultMemrefType = MemRefType::get( 2308 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2309 {}, srcType.getMemorySpaceAsInt()); 2310 } else { 2311 AffineMap map = srcType.getLayout().getAffineMap(); 2312 int numResultDims = map.getNumDims() - dimsToDrop; 2313 int numSymbols = map.getNumSymbols(); 2314 for (size_t i = 0; i < dimsToDrop; ++i) { 2315 int dim = srcType.getRank() - i - 1; 2316 map = map.replace(rewriter.getAffineDimExpr(dim), 2317 rewriter.getAffineConstantExpr(0), numResultDims, 2318 numSymbols); 2319 } 2320 resultMemrefType = MemRefType::get( 2321 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2322 map, srcType.getMemorySpaceAsInt()); 2323 } 2324 2325 auto loc = readOp.getLoc(); 2326 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2327 SmallVector<int64_t> strides(srcType.getRank(), 1); 2328 2329 ArrayAttr inBoundsAttr = 2330 readOp.in_bounds() 2331 ? 
rewriter.getArrayAttr( 2332 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2333 : ArrayAttr(); 2334 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2335 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2336 strides); 2337 auto permMap = getTransferMinorIdentityMap( 2338 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2339 Value result = rewriter.create<vector::TransferReadOp>( 2340 loc, resultTargetVecType, rankedReducedView, 2341 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2342 readOp.padding(), 2343 // TODO: support mask. 2344 /*mask=*/Value(), inBoundsAttr); 2345 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2346 result); 2347 return success(); 2348 } 2349 }; 2350 2351 namespace { 2352 2353 /// This function checks to see if the vector combining kind 2354 /// is consistent with the integer or float element type. 2355 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2356 using vector::CombiningKind; 2357 enum class KindType { FLOAT, INT, INVALID }; 2358 KindType type{KindType::INVALID}; 2359 switch (kind) { 2360 case CombiningKind::MINF: 2361 case CombiningKind::MAXF: 2362 type = KindType::FLOAT; 2363 break; 2364 case CombiningKind::MINUI: 2365 case CombiningKind::MINSI: 2366 case CombiningKind::MAXUI: 2367 case CombiningKind::MAXSI: 2368 case CombiningKind::AND: 2369 case CombiningKind::OR: 2370 case CombiningKind::XOR: 2371 type = KindType::INT; 2372 break; 2373 case CombiningKind::ADD: 2374 case CombiningKind::MUL: 2375 type = isInt ? KindType::INT : KindType::FLOAT; 2376 break; 2377 } 2378 bool isValidIntKind = (type == KindType::INT) && isInt; 2379 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2380 return (isValidIntKind || isValidFloatKind); 2381 } 2382 2383 /// This function constructs the appropriate integer or float 2384 /// operation given the vector combining kind and operands. The 2385 /// supported int operations are : add, mul, min (signed/unsigned), 2386 /// max(signed/unsigned), and, or, xor. The supported float 2387 /// operations are : add, mul, min and max. 
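/// For example, ADD on integer vectors maps to arith.addi and on float
/// vectors to arith.addf, while MINF/MAXF map to arith.minf/arith.maxf and
/// are only valid for float element types (see isValidKind above).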
2388 static Value genOperator(Location loc, Value x, Value y, 2389 vector::CombiningKind kind, 2390 PatternRewriter &rewriter) { 2391 using vector::CombiningKind; 2392 2393 auto elType = x.getType().cast<VectorType>().getElementType(); 2394 bool isInt = elType.isIntOrIndex(); 2395 2396 Value combinedResult{nullptr}; 2397 switch (kind) { 2398 case CombiningKind::ADD: 2399 if (isInt) 2400 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2401 else 2402 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2403 break; 2404 case CombiningKind::MUL: 2405 if (isInt) 2406 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2407 else 2408 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2409 break; 2410 case CombiningKind::MINUI: 2411 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2412 break; 2413 case CombiningKind::MINSI: 2414 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2415 break; 2416 case CombiningKind::MAXUI: 2417 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2418 break; 2419 case CombiningKind::MAXSI: 2420 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2421 break; 2422 case CombiningKind::AND: 2423 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2424 break; 2425 case CombiningKind::OR: 2426 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2427 break; 2428 case CombiningKind::XOR: 2429 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2430 break; 2431 case CombiningKind::MINF: 2432 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2433 break; 2434 case CombiningKind::MAXF: 2435 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2436 break; 2437 } 2438 return combinedResult; 2439 } 2440 2441 /// Convert vector.scan op into arith ops and 2442 /// vector.insert_strided_slice/extract_strided_slice 2443 /// 2444 /// Ex: 2445 /// ``` 2446 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2447 /// 1} : 2448 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2449 /// ``` 2450 /// Gets converted to: 2451 /// ``` 2452 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2453 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2454 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 = 2455 /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} 2456 /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice 2457 /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : 2458 /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : 2459 /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], 2460 /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = 2461 /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], 2462 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, 2463 /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = 2464 /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = 2465 /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : 2466 /// vector<2x3xi32>, vector<2xi32> 2467 /// ``` 2468 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { 2469 using OpRewritePattern<vector::ScanOp>::OpRewritePattern; 2470 2471 LogicalResult matchAndRewrite(vector::ScanOp scanOp, 2472 PatternRewriter &rewriter) const override { 2473 auto loc = scanOp.getLoc(); 2474 VectorType 
destType = scanOp.getDestType(); 2475 ArrayRef<int64_t> destShape = destType.getShape(); 2476 auto elType = destType.getElementType(); 2477 bool isInt = elType.isIntOrIndex(); 2478 if (!isValidKind(isInt, scanOp.kind())) 2479 return failure(); 2480 2481 VectorType resType = VectorType::get(destShape, elType); 2482 Value result = rewriter.create<arith::ConstantOp>( 2483 loc, resType, rewriter.getZeroAttr(resType)); 2484 int64_t reductionDim = scanOp.reduction_dim(); 2485 bool inclusive = scanOp.inclusive(); 2486 int64_t destRank = destType.getRank(); 2487 VectorType initialValueType = scanOp.getInitialValueType(); 2488 int64_t initialValueRank = initialValueType.getRank(); 2489 2490 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2491 reductionShape[reductionDim] = 1; 2492 VectorType reductionType = VectorType::get(reductionShape, elType); 2493 SmallVector<int64_t> offsets(destRank, 0); 2494 SmallVector<int64_t> strides(destRank, 1); 2495 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2496 sizes[reductionDim] = 1; 2497 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2498 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2499 2500 Value lastOutput, lastInput; 2501 for (int i = 0; i < destShape[reductionDim]; i++) { 2502 offsets[reductionDim] = i; 2503 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2504 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2505 loc, reductionType, scanOp.source(), scanOffsets, scanSizes, 2506 scanStrides); 2507 Value output; 2508 if (i == 0) { 2509 if (inclusive) { 2510 output = input; 2511 } else { 2512 if (initialValueRank == 0) { 2513 // ShapeCastOp cannot handle 0-D vectors 2514 output = rewriter.create<vector::BroadcastOp>( 2515 loc, input.getType(), scanOp.initial_value()); 2516 } else { 2517 output = rewriter.create<vector::ShapeCastOp>( 2518 loc, input.getType(), scanOp.initial_value()); 2519 } 2520 } 2521 } else { 2522 Value y = inclusive ? 
input : lastInput; 2523 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2524 assert(output != nullptr); 2525 } 2526 result = rewriter.create<vector::InsertStridedSliceOp>( 2527 loc, output, result, offsets, strides); 2528 lastOutput = output; 2529 lastInput = input; 2530 } 2531 2532 Value reduction; 2533 if (initialValueRank == 0) { 2534 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2535 reduction = 2536 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2537 } else { 2538 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2539 lastOutput); 2540 } 2541 2542 rewriter.replaceOp(scanOp, {result, reduction}); 2543 return success(); 2544 } 2545 }; 2546 2547 } // namespace 2548 2549 void mlir::vector::populateVectorMaskMaterializationPatterns( 2550 RewritePatternSet &patterns, bool indexOptimizations) { 2551 patterns.add<VectorCreateMaskOpConversion, 2552 MaterializeTransferMask<vector::TransferReadOp>, 2553 MaterializeTransferMask<vector::TransferWriteOp>>( 2554 patterns.getContext(), indexOptimizations); 2555 } 2556 2557 void mlir::vector::populateShapeCastFoldingPatterns( 2558 RewritePatternSet &patterns) { 2559 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2560 } 2561 2562 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2563 RewritePatternSet &patterns) { 2564 patterns.add<BubbleDownVectorBitCastForExtract, 2565 BubbleDownBitCastForStridedSliceExtract, 2566 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2567 } 2568 2569 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2570 RewritePatternSet &patterns) { 2571 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2572 } 2573 2574 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2575 RewritePatternSet &patterns) { 2576 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2577 patterns.getContext()); 2578 } 2579 2580 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2581 RewritePatternSet &patterns) { 2582 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2583 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2584 patterns.getContext()); 2585 } 2586 2587 void mlir::vector::populateVectorContractLoweringPatterns( 2588 RewritePatternSet &patterns, VectorTransformsOptions options) { 2589 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2590 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2591 ContractionOpToOuterProductOpLowering>(options, 2592 patterns.getContext()); 2593 } 2594 2595 void mlir::vector::populateVectorTransposeLoweringPatterns( 2596 RewritePatternSet &patterns, VectorTransformsOptions options) { 2597 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2598 options, patterns.getContext()); 2599 } 2600 2601 void mlir::vector::populateVectorReductionToContractPatterns( 2602 RewritePatternSet &patterns) { 2603 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2604 CombineContractTranspose>(patterns.getContext()); 2605 } 2606 2607 void mlir::vector:: 2608 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2609 RewritePatternSet &patterns) { 2610 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2611 } 2612 2613 void mlir::vector::populateVectorTransferLoweringPatterns( 2614 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2615 patterns.add<TransferReadToVectorLoadLowering, 2616 TransferWriteToVectorStoreLowering>(patterns.getContext(), 2617 maxTransferRank); 2618 patterns 
2619 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2620 patterns.getContext()); 2621 } 2622 2623 void mlir::vector::populateVectorScanLoweringPatterns( 2624 RewritePatternSet &patterns) { 2625 patterns.add<ScanToArithOps>(patterns.getContext()); 2626 } 2627
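// A minimal usage sketch, not part of this file's API: `ctx` and `funcOp` are
// assumed to come from a surrounding pass, and the greedy driver requires
// mlir/Transforms/GreedyPatternRewriteDriver.h. A pass typically collects some
// of the populate* entry points above into a RewritePatternSet and applies
// them greedily, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns,
//                                                  VectorTransformsOptions());
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));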