//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
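  // Illustrative example (not part of the original comments): dropping
  // index=1 at pos=2 from a vector<3x4x5xf32> extracts each of the 3 leading
  // slices (vector<4x5xf32>), recursively extracts position 2 of their
  // leading dimension (vector<5xf32>), and re-inserts the results into a
  // vector<3x5xf32>.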
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
// The following MLIR with cancelling ShapeCastOps:
//
// %0 = source : vector<5x4x2xf32>
// %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
// %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
// %3 = user %2 : vector<5x4x2xf32>
//
// Should canonicalize to the following:
//
// %0 = source : vector<5x4x2xf32>
// %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    // %x = broadcast %y : k-D to n-D, k < n
    // becomes:
    // %b = broadcast %y : k-D to (n-1)-D
    // %x = [%b,%b,%b,%b] : n-D
    // becomes:
    // %b = [%y,%y] : (n-1)-D
    // %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    // %x = [%a,%b,%c,%d]
    // becomes:
    // %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    // %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    // %a = [%u, %v]
    // ..
    // %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
/// %x = vector.transpose %y, [1, 0]
/// is replaced by:
/// %z = arith.constant dense<0.000000e+00>
/// %0 = vector.extract %y[0, 0]
/// %1 = vector.insert %0, %z [0, 0]
/// ..
/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.vector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specifies lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
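    // For example (illustrative), with the Flat option a
    //   %t = vector.transpose %v, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
    // is rewritten as a shape_cast to vector<8xf32>, a vector.flat_transpose
    // with rows=4 and columns=2, and a shape_cast back to vector<4x2xf32>.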
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
/// vector.shape_cast 2D -> 1D
/// vector.shuffle
/// vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
/// %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
/// %z = zero-result
/// %0 = vector.extract %lhs[0]
/// %1 = vector.broadcast %0
/// %2 = vector.extract %acc[0]
/// %3 = vector.fma %1, %rhs, %2
/// %4 = vector.insert %3, %z[0]
/// ..
/// %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
/// %x = vector.constant_mask [a,b]
/// is replaced by:
/// %z = zero-result
/// %l = vector.constant_mask [b]
/// %4 = vector.insert %l, %z[0]
/// ..
/// %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
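///
/// For example (illustrative):
/// %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// becomes a chain in which the 1-D mask
/// %l = vector.constant_mask [3] : vector<8xi1>
/// is inserted at positions 0 and 1 of an all-false vector<4x8xi1> constant.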
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      // [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
/// %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
/// %l = vector.create_mask ... : vector<...> ; one lower rank
/// %0 = arith.cmpi "slt", %ci, %a |
/// %1 = select %0, %l, %zeroes |
/// %r = vector.insert %1, %pr [i] | d-times
/// %x = ....
/// until a one-dimensional vector is reached.
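///
/// For example (illustrative):
/// %x = vector.create_mask %a, %b : vector<4x8xi1>
/// first creates the lower-rank mask
/// %l = vector.create_mask %b : vector<8xi1>
/// and then, for each i in [0, 4), selects either %l or an all-false
/// vector<8xi1> depending on whether i < %a, inserting the result at
/// position i.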
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
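///
/// For example (illustrative), casting vector<6xf32> to vector<2x3xf32>
/// extracts two strided slices of size 3 (at offsets 0 and 3) from the 1-D
/// source and inserts them as rows 0 and 1 of the 2-D result.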
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    // x[0,0,0] = y[0,0]
    // x[0,0,1] = y[0,1]
    // x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
/// %1 = vector.multi_reduction add, %0 [2]
///   : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.transpose %arg0, [2, 0, 1]
///   : vector<32x16x8xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///   affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %0, %arg1, %cst_f0
///   : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
///   affine_map<(d0, d1, d2) -> (d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///   affine_map<(d0, d1, d2) -> (d0, d1)>],
///   iterator_types = ["parallel", "parallel", "reduction"],
///   kind = add} %arg0, %arg1, %cst_f0
///   : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vector as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This brings broadcast and
/// contraction ops closer together, which lets the CombineContractBroadcast
/// pattern kick in when cast ops sit between these operations.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  bcastOp.source(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders cast(transpose) to transpose(cast). This brings transpose and
/// contraction ops closer together, which lets the CombineContractTranspose
/// pattern kick in when cast ops sit between these operations.
/// Ex:
/// ```
/// %0 = vector.transpose %arg0, [2, 0, 1]
///   : vector<32x16x8xi8> to vector<8x32x16xi8>
/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
/// %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
/// %1 = vector.transpose %0, [2, 0, 1]
///   : vector<32x16x8xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnTranspose
    : public OpInterfaceRewritePattern<CastOpInterface> {

  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
    if (!transpOp)
      return failure();

    auto castResTy = transpOp.getVectorType();
    castResTy = VectorType::get(castResTy.getShape(),
                                getElementTypeOrSelf(op->getResult(0)));
    auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                                  transpOp.vector(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResult(0).getType(), castOp->getResult(0),
        transpOp.getTransp());
    return success();
  }
};

} // namespace

/// Creates an AddIOp if `isInt` is true, otherwise creates an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true, otherwise creates a MulFOp, using
/// operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
/// %mta = maybe_transpose
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
/// %flattened_d = vector.matmul %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
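  // Illustrative shapes: a vector<4x8xf32> lhs and a vector<8x16xf32> rhs are
  // flattened to vector<32xf32> and vector<128xf32>, multiplied with
  // vector.matmul (lhsRows=4, lhsColumns=8, rhsColumns=16), and the resulting
  // vector<64xf32> is shape_cast back to vector<4x16xf32>.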
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
/// %a = vector.contract with one less free/batch dimension
/// %b = vector.contract with one less free/batch dimension
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
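///
/// For example (illustrative), a batched matmul contraction taking
/// vector<2x3x4xf32> and vector<2x4x5xf32> into vector<2x3x5xf32> is first
/// unrolled along the batch dimension into two rank-reduced contractions,
/// which are decomposed further until only dot products on 1-D vectors
/// remain.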
1509 // 1510 // TODO: break down into transpose/reshape/cast ops 1511 // when they become available to avoid code dup 1512 // TODO: investigate lowering order impact on performance 1513 LogicalResult 1514 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1515 PatternRewriter &rewriter) const { 1516 // TODO: implement masks. 1517 if (llvm::size(op.masks()) != 0) 1518 return failure(); 1519 1520 if (failed(filter(op))) 1521 return failure(); 1522 1523 // TODO: support mixed mode contract lowering. 1524 if (op.getLhsType().getElementType() != 1525 getElementTypeOrSelf(op.getAccType()) || 1526 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1527 return failure(); 1528 1529 // TODO: implement benefits, cost models. 1530 MLIRContext *ctx = op.getContext(); 1531 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1532 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1533 return success(); 1534 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1535 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1536 return success(); 1537 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1538 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1539 return success(); 1540 1541 // Find first batch dimension in LHS/RHS, and lower when found. 1542 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1543 if (!batchDimMap.empty()) { 1544 int64_t lhsIndex = batchDimMap[0].first; 1545 int64_t rhsIndex = batchDimMap[0].second; 1546 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1547 return success(); 1548 } 1549 1550 // Collect contracting dimensions. 1551 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1552 op.getContractingDimMap(); 1553 DenseSet<int64_t> lhsContractingDimSet; 1554 DenseSet<int64_t> rhsContractingDimSet; 1555 for (auto &dimPair : contractingDimMap) { 1556 lhsContractingDimSet.insert(dimPair.first); 1557 rhsContractingDimSet.insert(dimPair.second); 1558 } 1559 1560 // Find first free dimension in LHS, and lower when found. 1561 VectorType lhsType = op.getLhsType(); 1562 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1563 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1564 rewriter.replaceOp( 1565 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1566 return success(); 1567 } 1568 } 1569 1570 // Find first free dimension in RHS, and lower when found. 1571 VectorType rhsType = op.getRhsType(); 1572 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1573 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1574 rewriter.replaceOp( 1575 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1576 return success(); 1577 } 1578 } 1579 1580 // Lower the first remaining reduction dimension. 1581 if (!contractingDimMap.empty()) { 1582 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1583 return success(); 1584 } 1585 1586 return failure(); 1587 } 1588 1589 // Lower one parallel dimension. 1590 // TODO: consider reusing existing contract unrolling 1591 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1592 int64_t lhsIndex, int64_t rhsIndex, 1593 PatternRewriter &rewriter) const { 1594 VectorType lhsType = op.getLhsType(); 1595 VectorType rhsType = op.getRhsType(); 1596 VectorType resType = op.getResultType().cast<VectorType>(); 1597 // Find the iterator type index and result index. 
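  // For example (illustrative): with indexing maps
  //   lhs: (d0, d1, d2) -> (d0, d2), rhs: (d0, d1, d2) -> (d2, d1),
  //   acc: (d0, d1, d2) -> (d0, d1)
  // and lhsIndex == 0, the parallel iterator is d0, so iterIndex == 0 and
  // resIndex == 0 (the position of d0 in the accumulator map).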
1598 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1599 int64_t iterIndex = -1; 1600 int64_t dimSize = -1; 1601 if (lhsIndex >= 0) { 1602 iterIndex = iMap[0].getDimPosition(lhsIndex); 1603 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1604 "parallel index should be free in LHS or batch in LHS/RHS"); 1605 dimSize = lhsType.getDimSize(lhsIndex); 1606 } else { 1607 assert(rhsIndex >= 0 && "missing parallel index"); 1608 iterIndex = iMap[1].getDimPosition(rhsIndex); 1609 dimSize = rhsType.getDimSize(rhsIndex); 1610 } 1611 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1612 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1613 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1614 int64_t resIndex = lookup.getValue(); 1615 // Construct new iterator types and affine map array attribute. 1616 std::array<AffineMap, 3> lowIndexingMaps = { 1617 adjustMap(iMap[0], iterIndex, rewriter), 1618 adjustMap(iMap[1], iterIndex, rewriter), 1619 adjustMap(iMap[2], iterIndex, rewriter)}; 1620 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1621 auto lowIter = 1622 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1623 // Unroll into a series of lower dimensional vector.contract ops. 1624 Location loc = op.getLoc(); 1625 Value result = rewriter.create<arith::ConstantOp>( 1626 loc, resType, rewriter.getZeroAttr(resType)); 1627 for (int64_t d = 0; d < dimSize; ++d) { 1628 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1629 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1630 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1631 Value lowContract = rewriter.create<vector::ContractionOp>( 1632 loc, lhs, rhs, acc, lowAffine, lowIter); 1633 result = 1634 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1635 } 1636 return result; 1637 } 1638 1639 // Lower one reduction dimension. 1640 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1641 PatternRewriter &rewriter) const { 1642 auto loc = op.getLoc(); 1643 VectorType lhsType = op.getLhsType(); 1644 VectorType rhsType = op.getRhsType(); 1645 Type resType = op.getResultType(); 1646 assert(!resType.isa<VectorType>()); 1647 bool isInt = resType.isa<IntegerType>(); 1648 // Use iterator index 0. 1649 int64_t iterIndex = 0; 1650 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1651 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1652 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1653 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1654 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1655 int64_t lhsIndex = lookupLhs.getValue(); 1656 int64_t rhsIndex = lookupRhs.getValue(); 1657 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1658 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1659 // Base case. 1660 if (lhsType.getRank() == 1) { 1661 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1662 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1663 auto kind = vector::CombiningKind::ADD; 1664 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1665 if (auto acc = op.acc()) 1666 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1667 return res; 1668 } 1669 // Construct new iterator types and affine map array attribute. 
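  // For example (illustrative): reducing a rank-2 contraction with maps
  //   lhs: (d0, d1) -> (d0, d1), rhs: (d0, d1) -> (d0, d1), acc: (d0, d1) -> ()
  // along d0 yields, after dropping d0, the rank-1 maps
  //   lhs: (d0) -> (d0), rhs: (d0) -> (d0), acc: (d0) -> ()
  // and one lower-rank vector.contract per slice along the original d0.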
1670 std::array<AffineMap, 3> lowIndexingMaps = { 1671 adjustMap(iMap[0], iterIndex, rewriter), 1672 adjustMap(iMap[1], iterIndex, rewriter), 1673 adjustMap(iMap[2], iterIndex, rewriter)}; 1674 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1675 auto lowIter = 1676 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1677 // Unroll into a series of lower dimensional vector.contract ops. 1678 // By feeding the initial accumulator into the first contraction, 1679 // and the result of each contraction into the next, eventually 1680 // the sum of all reductions is computed. 1681 Value result = op.acc(); 1682 for (int64_t d = 0; d < dimSize; ++d) { 1683 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1684 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1685 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1686 lowAffine, lowIter); 1687 } 1688 return result; 1689 } 1690 1691 } // namespace mlir 1692 1693 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1694 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1695 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1696 OpBuilder::InsertionGuard guard(builder); 1697 builder.setInsertionPointAfter(op); 1698 Location loc = op->getLoc(); 1699 if (op->getNumResults() != 1) 1700 return {}; 1701 Value result = op->getResult(0); 1702 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1703 if (!type || map.getNumResults() != multiplicity.size()) 1704 return {}; 1705 // For each dimension being distributed check that the size is a multiple of 1706 // the multiplicity. To handle more sizes we would need to support masking. 1707 unsigned multiplictyCount = 0; 1708 for (auto exp : map.getResults()) { 1709 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1710 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1711 type.getDimSize(affinExp.getPosition()) % 1712 multiplicity[multiplictyCount++] != 1713 0) 1714 return {}; 1715 } 1716 DistributeOps ops; 1717 ops.extract = 1718 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1719 ops.insert = 1720 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1721 return ops; 1722 } 1723 1724 /// Progressive lowering of transfer_read. This pattern supports lowering of 1725 /// `vector.transfer_read` to a combination of `vector.load` and 1726 /// `vector.broadcast` if all of the following hold: 1727 /// - Stride of most minor memref dimension must be 1. 1728 /// - Out-of-bounds masking is not required. 1729 /// - If the memref's element type is a vector type then it coincides with the 1730 /// result type. 1731 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1732 struct TransferReadToVectorLoadLowering 1733 : public OpRewritePattern<vector::TransferReadOp> { 1734 TransferReadToVectorLoadLowering(MLIRContext *context, 1735 llvm::Optional<unsigned> maxRank) 1736 : OpRewritePattern<vector::TransferReadOp>(context), 1737 maxTransferRank(maxRank) {} 1738 1739 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1740 PatternRewriter &rewriter) const override { 1741 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1742 return failure(); 1743 1744 SmallVector<unsigned, 4> broadcastedDims; 1745 // Permutations are handled by VectorToSCF or 1746 // populateVectorTransferPermutationMapLoweringPatterns. 1747 // We let the 0-d corner case pass-through as it is supported. 
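    // E.g. (illustrative) a permutation map (d0, d1, d2) -> (0, d2) is a minor
    // identity with broadcasting: the constant-0 result means vector dimension
    // 0 is broadcast, so `broadcastedDims` becomes {0}.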
1748 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1749 &broadcastedDims)) 1750 return failure(); 1751 1752 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1753 if (!memRefType) 1754 return failure(); 1755 1756 // Non-unit strides are handled by VectorToSCF. 1757 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1758 return failure(); 1759 1760 // If there is broadcasting involved then we first load the unbroadcasted 1761 // vector, and then broadcast it with `vector.broadcast`. 1762 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1763 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1764 vectorShape.end()); 1765 for (unsigned i : broadcastedDims) 1766 unbroadcastedVectorShape[i] = 1; 1767 VectorType unbroadcastedVectorType = VectorType::get( 1768 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1769 1770 // `vector.load` supports vector types as memref's elements only when the 1771 // resulting vector type is the same as the element type. 1772 auto memrefElTy = memRefType.getElementType(); 1773 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1774 return failure(); 1775 1776 // Otherwise, element types of the memref and the vector must match. 1777 if (!memrefElTy.isa<VectorType>() && 1778 memrefElTy != read.getVectorType().getElementType()) 1779 return failure(); 1780 1781 // Out-of-bounds dims are handled by MaterializeTransferMask. 1782 if (read.hasOutOfBoundsDim()) 1783 return failure(); 1784 1785 // Create vector load op. 1786 Operation *loadOp; 1787 if (read.mask()) { 1788 Value fill = rewriter.create<vector::SplatOp>( 1789 read.getLoc(), unbroadcastedVectorType, read.padding()); 1790 loadOp = rewriter.create<vector::MaskedLoadOp>( 1791 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1792 read.mask(), fill); 1793 } else { 1794 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1795 unbroadcastedVectorType, 1796 read.source(), read.indices()); 1797 } 1798 1799 // Insert a broadcasting op if required. 1800 if (!broadcastedDims.empty()) { 1801 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1802 read, read.getVectorType(), loadOp->getResult(0)); 1803 } else { 1804 rewriter.replaceOp(read, loadOp->getResult(0)); 1805 } 1806 1807 return success(); 1808 } 1809 1810 llvm::Optional<unsigned> maxTransferRank; 1811 }; 1812 1813 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1814 // TODO: we shouldn't cross the vector/scalar domains just for this 1815 // but atm we lack the infra to avoid it. Possible solutions include: 1816 // - go directly to LLVM + bitcast 1817 // - introduce a bitcast op and likely a new pointer dialect 1818 // - let memref.load/store additionally support the 0-d vector case 1819 // There are still deeper data layout issues lingering even in this 1820 // trivial case (for architectures for which this matters). 
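//
// For example (illustrative):
//   %0 = vector.load %base[%i] : memref<4xf32>, vector<1xf32>
// becomes:
//   %s = memref.load %base[%i] : memref<4xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>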
1821 struct VectorLoadToMemrefLoadLowering
1822     : public OpRewritePattern<vector::LoadOp> {
1823   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1824 
1825   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1826                                 PatternRewriter &rewriter) const override {
1827     auto vecType = loadOp.getVectorType();
1828     if (vecType.getNumElements() != 1)
1829       return failure();
1830     auto memrefLoad = rewriter.create<memref::LoadOp>(
1831         loadOp.getLoc(), loadOp.base(), loadOp.indices());
1832     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1833                                                      memrefLoad);
1834     return success();
1835   }
1836 };
1837 
1838 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1839 struct VectorStoreToMemrefStoreLowering
1840     : public OpRewritePattern<vector::StoreOp> {
1841   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1842 
1843   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1844                                 PatternRewriter &rewriter) const override {
1845     auto vecType = storeOp.getVectorType();
1846     if (vecType.getNumElements() != 1)
1847       return failure();
1848     Value extracted;
1849     if (vecType.getRank() == 0) {
1850       // TODO: Unify once ExtractOp supports 0-d vectors.
1851       extracted = rewriter.create<vector::ExtractElementOp>(
1852           storeOp.getLoc(), storeOp.valueToStore());
1853     } else {
1854       SmallVector<int64_t> indices(vecType.getRank(), 0);
1855       extracted = rewriter.create<vector::ExtractOp>(
1856           storeOp.getLoc(), storeOp.valueToStore(), indices);
1857     }
1858 
1859     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1860         storeOp, extracted, storeOp.base(), storeOp.indices());
1861     return success();
1862   }
1863 };
1864 
1865 /// Progressive lowering of transfer_write. This pattern supports lowering of
1866 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1867 /// - Stride of most minor memref dimension must be 1.
1868 /// - Out-of-bounds masking is not required.
1869 /// - If the memref's element type is a vector type then it coincides with the
1870 ///   type of the written value.
1871 /// - The permutation map is the minor identity map (neither permutation nor
1872 ///   broadcasting is allowed).
1873 struct TransferWriteToVectorStoreLowering
1874     : public OpRewritePattern<vector::TransferWriteOp> {
1875   TransferWriteToVectorStoreLowering(MLIRContext *context,
1876                                      llvm::Optional<unsigned> maxRank)
1877       : OpRewritePattern<vector::TransferWriteOp>(context),
1878         maxTransferRank(maxRank) {}
1879 
1880   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1881                                 PatternRewriter &rewriter) const override {
1882     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1883       return failure();
1884 
1885     // Permutations are handled by VectorToSCF or
1886     // populateVectorTransferPermutationMapLoweringPatterns.
1887     if ( // pass-through for the 0-d corner case.
1888         !write.permutation_map().isMinorIdentity())
1889       return failure();
1890 
1891     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1892     if (!memRefType)
1893       return failure();
1894 
1895     // Non-unit strides are handled by VectorToSCF.
1896     if (!vector::isLastMemrefDimUnitStride(memRefType))
1897       return failure();
1898 
1899     // `vector.store` supports vector types as memref's elements only when the
1900     // type of the vector value being written is the same as the element type.
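    // E.g. (illustrative) writing a vector<4xf32> into memref<8xvector<4xf32>>
    // is supported here, while writing a vector<2xf32> into that memref is not.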
1901     auto memrefElTy = memRefType.getElementType();
1902     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1903       return failure();
1904 
1905     // Otherwise, element types of the memref and the vector must match.
1906     if (!memrefElTy.isa<VectorType>() &&
1907         memrefElTy != write.getVectorType().getElementType())
1908       return failure();
1909 
1910     // Out-of-bounds dims are handled by MaterializeTransferMask.
1911     if (write.hasOutOfBoundsDim())
1912       return failure();
1913     if (write.mask()) {
1914       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1915           write, write.source(), write.indices(), write.mask(), write.vector());
1916     } else {
1917       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1918           write, write.vector(), write.source(), write.indices());
1919     }
1920     return success();
1921   }
1922 
1923   llvm::Optional<unsigned> maxTransferRank;
1924 };
1925 
1926 // Returns the values in `arrayAttr` as an integer vector.
1927 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1928   return llvm::to_vector<4>(
1929       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1930                       [](IntegerAttr attr) { return attr.getInt(); }));
1931 }
1932 
1933 // Shuffles vector.bitcast op after vector.extract op.
1934 //
1935 // This transforms IR like:
1936 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1937 //   %1 = vector.extract %0[3] : vector<8xf16>
1938 // Into:
1939 //   %0 = vector.extract %src[1] : vector<4xf32>
1940 //   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
1941 //   %2 = vector.extract %1[1] : vector<2xf16>
1942 struct BubbleDownVectorBitCastForExtract
1943     : public OpRewritePattern<vector::ExtractOp> {
1944   using OpRewritePattern::OpRewritePattern;
1945 
1946   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1947                                 PatternRewriter &rewriter) const override {
1948     // Only support extracting scalars for now.
1949     if (extractOp.getVectorType().getRank() != 1)
1950       return failure();
1951 
1952     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
1953     if (!castOp)
1954       return failure();
1955 
1956     VectorType castSrcType = castOp.getSourceVectorType();
1957     VectorType castDstType = castOp.getResultVectorType();
1958     assert(castSrcType.getRank() == castDstType.getRank());
1959 
1960     // Fail to match if we only have one element in the cast op source.
1961     // This is to avoid an infinite loop, given that this pattern can generate
1962     // such cases.
1963     if (castSrcType.getNumElements() == 1)
1964       return failure();
1965 
1966     // Only support casting to a larger number of elements for now.
1967     // E.g., vector<4xf32> -> vector<8xf16>.
1968     if (castSrcType.getNumElements() > castDstType.getNumElements())
1969       return failure();
1970 
1971     unsigned expandRatio =
1972         castDstType.getNumElements() / castSrcType.getNumElements();
1973 
1974     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1975       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1976     };
1977 
1978     uint64_t index = getFirstIntValue(extractOp.position());
1979 
1980     // Get the single scalar (as a vector) in the source value that packs the
1981     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
1982     VectorType oneScalarType =
1983         VectorType::get({1}, castSrcType.getElementType());
1984     Value packedValue = rewriter.create<vector::ExtractOp>(
1985         extractOp.getLoc(), oneScalarType, castOp.source(),
1986         rewriter.getI64ArrayAttr(index / expandRatio));
1987 
1988     // Cast it to a vector with the desired scalar's type.
1989     // E.g. f32 -> vector<2xf16>
1990     VectorType packedType =
1991         VectorType::get({expandRatio}, castDstType.getElementType());
1992     Value castedValue = rewriter.create<vector::BitCastOp>(
1993         extractOp.getLoc(), packedType, packedValue);
1994 
1995     // Finally extract the desired scalar.
1996     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1997         extractOp, extractOp.getType(), castedValue,
1998         rewriter.getI64ArrayAttr(index % expandRatio));
1999 
2000     return success();
2001   }
2002 };
2003 
2004 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2005 //
2006 // This transforms IR like:
2007 //   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
2008 //   %0 = vector.extract_strided_slice %cast {
2009 //     offsets = [4], sizes = [4], strides = [1]
2010 //   } : vector<8xf16> to vector<4xf16>
2011 // Into:
2012 //   %0 = vector.extract_strided_slice %arg0 {
2013 //     offsets = [2], sizes = [2], strides = [1]
2014 //   } : vector<4xf32> to vector<2xf32>
2015 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2016 struct BubbleDownBitCastForStridedSliceExtract
2017     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2018   using OpRewritePattern::OpRewritePattern;
2019 
2020   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2021                                 PatternRewriter &rewriter) const override {
2022     auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
2023     if (!castOp)
2024       return failure();
2025 
2026     VectorType castSrcType = castOp.getSourceVectorType();
2027     VectorType castDstType = castOp.getResultVectorType();
2028     assert(castSrcType.getRank() == castDstType.getRank());
2029 
2030     int64_t castSrcLastDim = castSrcType.getShape().back();
2031     int64_t castDstLastDim = castDstType.getShape().back();
2032     // Require casting to more elements for now; other cases to be implemented.
2033     if (castSrcLastDim > castDstLastDim)
2034       return failure();
2035 
2036     // Only accept all one strides for now.
2037     if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
2038                      [](const APInt &val) { return !val.isOneValue(); }))
2039       return failure();
2040 
2041     unsigned rank = extractOp.getVectorType().getRank();
2042     assert(castDstLastDim % castSrcLastDim == 0);
2043     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2044 
2045     // If we have fewer offsets than the rank, then we are implicitly selecting
2046     // the full range for the last bitcasted dimension; other dimensions are
2047     // not affected. Otherwise, we need to scale down the last dimension's
2048     // offset since we are now extracting from fewer elements.
2049     ArrayAttr newOffsets = extractOp.offsets();
2050     if (newOffsets.size() == rank) {
2051       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2052       if (offsets.back() % expandRatio != 0)
2053         return failure();
2054       offsets.back() = offsets.back() / expandRatio;
2055       newOffsets = rewriter.getI64ArrayAttr(offsets);
2056     }
2057 
2058     // Similarly for sizes.
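    // E.g. (continuing the example above) with expandRatio == 2, a size of 4
    // in f16 elements becomes a size of 2 in f32 elements.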
2059 ArrayAttr newSizes = extractOp.sizes(); 2060 if (newSizes.size() == rank) { 2061 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2062 if (sizes.back() % expandRatio != 0) 2063 return failure(); 2064 sizes.back() = sizes.back() / expandRatio; 2065 newSizes = rewriter.getI64ArrayAttr(sizes); 2066 } 2067 2068 SmallVector<int64_t, 4> dims = 2069 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2070 dims.back() = dims.back() / expandRatio; 2071 VectorType newExtractType = 2072 VectorType::get(dims, castSrcType.getElementType()); 2073 2074 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2075 extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, 2076 newSizes, extractOp.strides()); 2077 2078 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2079 extractOp, extractOp.getType(), newExtractOp); 2080 2081 return success(); 2082 } 2083 }; 2084 2085 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2086 // 2087 // This transforms IR like: 2088 // %0 = vector.insert_strided_slice %src, %dst { 2089 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2090 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2091 // Into: 2092 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2093 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2094 // %2 = vector.insert_strided_slice %src, %dst { 2095 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2096 struct BubbleUpBitCastForStridedSliceInsert 2097 : public OpRewritePattern<vector::BitCastOp> { 2098 using OpRewritePattern::OpRewritePattern; 2099 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2100 PatternRewriter &rewriter) const override { 2101 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2102 VectorType castDstType = bitcastOp.getResultVectorType(); 2103 assert(castSrcType.getRank() == castDstType.getRank()); 2104 2105 int64_t castSrcLastDim = castSrcType.getShape().back(); 2106 int64_t castDstLastDim = castDstType.getShape().back(); 2107 // Require casting to less elements for now; other cases to be implemented. 2108 if (castSrcLastDim < castDstLastDim) 2109 return failure(); 2110 2111 assert(castSrcLastDim % castDstLastDim == 0); 2112 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2113 2114 auto insertOp = 2115 bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>(); 2116 if (!insertOp) 2117 return failure(); 2118 2119 // Only accept all one strides for now. 2120 if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(), 2121 [](const APInt &val) { return !val.isOneValue(); })) 2122 return failure(); 2123 2124 unsigned rank = insertOp.getSourceVectorType().getRank(); 2125 // Require insert op to have the same rank for the source and destination 2126 // vector; other cases to be implemented. 
2127 if (rank != insertOp.getDestVectorType().getRank()) 2128 return failure(); 2129 2130 ArrayAttr newOffsets = insertOp.offsets(); 2131 assert(newOffsets.size() == rank); 2132 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2133 if (offsets.back() % shrinkRatio != 0) 2134 return failure(); 2135 offsets.back() = offsets.back() / shrinkRatio; 2136 newOffsets = rewriter.getI64ArrayAttr(offsets); 2137 2138 SmallVector<int64_t, 4> srcDims = 2139 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2140 srcDims.back() = srcDims.back() / shrinkRatio; 2141 VectorType newCastSrcType = 2142 VectorType::get(srcDims, castDstType.getElementType()); 2143 2144 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2145 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2146 2147 SmallVector<int64_t, 4> dstDims = 2148 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2149 dstDims.back() = dstDims.back() / shrinkRatio; 2150 VectorType newCastDstType = 2151 VectorType::get(dstDims, castDstType.getElementType()); 2152 2153 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2154 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2155 2156 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2157 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2158 insertOp.strides()); 2159 2160 return success(); 2161 } 2162 }; 2163 2164 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, 2165 Type targetType, Value value) { 2166 if (targetType == value.getType()) 2167 return value; 2168 2169 bool targetIsIndex = targetType.isIndex(); 2170 bool valueIsIndex = value.getType().isIndex(); 2171 if (targetIsIndex ^ valueIsIndex) 2172 return rewriter.create<arith::IndexCastOp>(loc, targetType, value); 2173 2174 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 2175 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 2176 assert(targetIntegerType && valueIntegerType && 2177 "unexpected cast between types other than integers and index"); 2178 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 2179 2180 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 2181 return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value); 2182 return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value); 2183 } 2184 2185 // Helper that returns a vector comparison that constructs a mask: 2186 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2187 // 2188 // If `dim == 0` then the result will be a 0-D vector. 2189 // 2190 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2191 // much more compact, IR for this operation, but LLVM eventually 2192 // generates more elaborate instructions for this intrinsic since it 2193 // is very conservative on the boundary conditions. 2194 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2195 bool indexOptimizations, int64_t dim, 2196 Value b, Value *off = nullptr) { 2197 auto loc = op->getLoc(); 2198 // If we can assume all indices fit in 32-bit, we perform the vector 2199 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2200 // Otherwise we perform the vector comparison using 64-bit indices. 2201 Type idxType = 2202 indexOptimizations ? 
rewriter.getI32Type() : rewriter.getI64Type(); 2203 DenseIntElementsAttr indicesAttr; 2204 if (dim == 0 && indexOptimizations) { 2205 indicesAttr = DenseIntElementsAttr::get( 2206 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2207 } else if (dim == 0) { 2208 indicesAttr = DenseIntElementsAttr::get( 2209 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2210 } else if (indexOptimizations) { 2211 indicesAttr = rewriter.getI32VectorAttr( 2212 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2213 } else { 2214 indicesAttr = rewriter.getI64VectorAttr( 2215 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2216 } 2217 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2218 // Add in an offset if requested. 2219 if (off) { 2220 Value o = createCastToIndexLike(rewriter, loc, idxType, *off); 2221 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2222 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2223 } 2224 // Construct the vector comparison. 2225 Value bound = createCastToIndexLike(rewriter, loc, idxType, b); 2226 Value bounds = 2227 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2228 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2229 bounds); 2230 } 2231 2232 template <typename ConcreteOp> 2233 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2234 public: 2235 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2236 : mlir::OpRewritePattern<ConcreteOp>(context), 2237 indexOptimizations(enableIndexOpt) {} 2238 2239 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2240 PatternRewriter &rewriter) const override { 2241 if (!xferOp.hasOutOfBoundsDim()) 2242 return failure(); 2243 2244 if (xferOp.getVectorType().getRank() > 1 || 2245 llvm::size(xferOp.indices()) == 0) 2246 return failure(); 2247 2248 Location loc = xferOp->getLoc(); 2249 VectorType vtp = xferOp.getVectorType(); 2250 2251 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2252 // set and [dim - offset .. vector_length) unset. 2253 // 2254 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2255 // dimensions here. 2256 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2257 Value off = xferOp.indices()[lastIndex]; 2258 Value dim = 2259 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2260 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2261 Value mask = rewriter.create<vector::CreateMaskOp>( 2262 loc, 2263 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2264 vtp.getNumScalableDims()), 2265 b); 2266 if (xferOp.mask()) { 2267 // Intersect the in-bounds with the mask specified as an op parameter. 2268 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2269 } 2270 2271 rewriter.updateRootInPlace(xferOp, [&]() { 2272 xferOp.maskMutable().assign(mask); 2273 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2274 }); 2275 2276 return success(); 2277 } 2278 2279 private: 2280 const bool indexOptimizations; 2281 }; 2282 2283 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 
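/// As a rough sketch (illustrative, with index optimizations enabled),
///   %m = vector.create_mask %b : vector<4xi1>
/// becomes
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %b32 = arith.index_cast %b : index to i32
///   %bnd = vector.splat %b32 : vector<4xi32>
///   %m   = arith.cmpi slt, %idx, %bnd : vector<4xi32>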
2284 class VectorCreateMaskOpConversion 2285 : public OpRewritePattern<vector::CreateMaskOp> { 2286 public: 2287 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2288 bool enableIndexOpt) 2289 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2290 indexOptimizations(enableIndexOpt) {} 2291 2292 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2293 PatternRewriter &rewriter) const override { 2294 auto dstType = op.getType(); 2295 int64_t rank = dstType.getRank(); 2296 if (rank > 1) 2297 return failure(); 2298 rewriter.replaceOp( 2299 op, buildVectorComparison(rewriter, op, indexOptimizations, 2300 rank == 0 ? 0 : dstType.getDimSize(0), 2301 op.getOperand(0))); 2302 return success(); 2303 } 2304 2305 private: 2306 const bool indexOptimizations; 2307 }; 2308 2309 // Drop inner most contiguous unit dimensions from transfer_read operand. 2310 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2311 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2312 2313 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2314 PatternRewriter &rewriter) const override { 2315 // TODO: support 0-d corner case. 2316 if (readOp.getTransferRank() == 0) 2317 return failure(); 2318 2319 // TODO: support mask. 2320 if (readOp.mask()) 2321 return failure(); 2322 2323 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2324 if (!srcType || !srcType.hasStaticShape()) 2325 return failure(); 2326 2327 if (!readOp.permutation_map().isMinorIdentity()) 2328 return failure(); 2329 2330 auto targetType = readOp.getVectorType(); 2331 if (targetType.getRank() <= 1) 2332 return failure(); 2333 2334 SmallVector<int64_t> srcStrides; 2335 int64_t srcOffset; 2336 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2337 return failure(); 2338 2339 size_t dimsToDrop = 0; 2340 for (size_t i = 1; i < srcStrides.size(); ++i) { 2341 int dim = srcType.getRank() - i - 1; 2342 if (srcStrides[dim] == 1) { 2343 dimsToDrop++; 2344 } else { 2345 break; 2346 } 2347 } 2348 if (dimsToDrop == 0) 2349 return failure(); 2350 2351 auto resultTargetVecType = 2352 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2353 targetType.getElementType()); 2354 2355 MemRefType resultMemrefType; 2356 if (srcType.getLayout().getAffineMap().isIdentity()) { 2357 resultMemrefType = MemRefType::get( 2358 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2359 {}, srcType.getMemorySpaceAsInt()); 2360 } else { 2361 AffineMap map = srcType.getLayout().getAffineMap(); 2362 int numResultDims = map.getNumDims() - dimsToDrop; 2363 int numSymbols = map.getNumSymbols(); 2364 for (size_t i = 0; i < dimsToDrop; ++i) { 2365 int dim = srcType.getRank() - i - 1; 2366 map = map.replace(rewriter.getAffineDimExpr(dim), 2367 rewriter.getAffineConstantExpr(0), numResultDims, 2368 numSymbols); 2369 } 2370 resultMemrefType = MemRefType::get( 2371 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2372 map, srcType.getMemorySpaceAsInt()); 2373 } 2374 2375 auto loc = readOp.getLoc(); 2376 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2377 SmallVector<int64_t> strides(srcType.getRank(), 1); 2378 2379 ArrayAttr inBoundsAttr = 2380 readOp.in_bounds() 2381 ? 
rewriter.getArrayAttr( 2382 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2383 : ArrayAttr(); 2384 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2385 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2386 strides); 2387 auto permMap = getTransferMinorIdentityMap( 2388 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2389 Value result = rewriter.create<vector::TransferReadOp>( 2390 loc, resultTargetVecType, rankedReducedView, 2391 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2392 readOp.padding(), 2393 // TODO: support mask. 2394 /*mask=*/Value(), inBoundsAttr); 2395 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2396 result); 2397 return success(); 2398 } 2399 }; 2400 2401 namespace { 2402 2403 /// This function checks to see if the vector combining kind 2404 /// is consistent with the integer or float element type. 2405 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2406 using vector::CombiningKind; 2407 enum class KindType { FLOAT, INT, INVALID }; 2408 KindType type{KindType::INVALID}; 2409 switch (kind) { 2410 case CombiningKind::MINF: 2411 case CombiningKind::MAXF: 2412 type = KindType::FLOAT; 2413 break; 2414 case CombiningKind::MINUI: 2415 case CombiningKind::MINSI: 2416 case CombiningKind::MAXUI: 2417 case CombiningKind::MAXSI: 2418 case CombiningKind::AND: 2419 case CombiningKind::OR: 2420 case CombiningKind::XOR: 2421 type = KindType::INT; 2422 break; 2423 case CombiningKind::ADD: 2424 case CombiningKind::MUL: 2425 type = isInt ? KindType::INT : KindType::FLOAT; 2426 break; 2427 } 2428 bool isValidIntKind = (type == KindType::INT) && isInt; 2429 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2430 return (isValidIntKind || isValidFloatKind); 2431 } 2432 2433 /// This function constructs the appropriate integer or float 2434 /// operation given the vector combining kind and operands. The 2435 /// supported int operations are : add, mul, min (signed/unsigned), 2436 /// max(signed/unsigned), and, or, xor. The supported float 2437 /// operations are : add, mul, min and max. 
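/// For example (illustrative): for CombiningKind::MAXSI on vector<4xi32>
/// operands this materializes `arith.maxsi`, and for CombiningKind::MINF on
/// vector<4xf32> operands it materializes `arith.minf`.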
2438 static Value genOperator(Location loc, Value x, Value y,
2439                          vector::CombiningKind kind,
2440                          PatternRewriter &rewriter) {
2441   using vector::CombiningKind;
2442 
2443   auto elType = x.getType().cast<VectorType>().getElementType();
2444   bool isInt = elType.isIntOrIndex();
2445 
2446   Value combinedResult{nullptr};
2447   switch (kind) {
2448   case CombiningKind::ADD:
2449     if (isInt)
2450       combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
2451     else
2452       combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
2453     break;
2454   case CombiningKind::MUL:
2455     if (isInt)
2456       combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
2457     else
2458       combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
2459     break;
2460   case CombiningKind::MINUI:
2461     combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
2462     break;
2463   case CombiningKind::MINSI:
2464     combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
2465     break;
2466   case CombiningKind::MAXUI:
2467     combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
2468     break;
2469   case CombiningKind::MAXSI:
2470     combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
2471     break;
2472   case CombiningKind::AND:
2473     combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
2474     break;
2475   case CombiningKind::OR:
2476     combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
2477     break;
2478   case CombiningKind::XOR:
2479     combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
2480     break;
2481   case CombiningKind::MINF:
2482     combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
2483     break;
2484   case CombiningKind::MAXF:
2485     combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
2486     break;
2487   }
2488   return combinedResult;
2489 }
2490 
2491 /// Convert vector.scan op into arith ops and
2492 /// vector.insert_strided_slice/extract_strided_slice.
2493 ///
2494 /// Ex:
2495 /// ```
///   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 1}
///     : (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.addi %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.addi %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
2518 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2519   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2520 
2521   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2522                                 PatternRewriter &rewriter) const override {
2523     auto loc = scanOp.getLoc();
2524     VectorType
destType = scanOp.getDestType(); 2525 ArrayRef<int64_t> destShape = destType.getShape(); 2526 auto elType = destType.getElementType(); 2527 bool isInt = elType.isIntOrIndex(); 2528 if (!isValidKind(isInt, scanOp.kind())) 2529 return failure(); 2530 2531 VectorType resType = VectorType::get(destShape, elType); 2532 Value result = rewriter.create<arith::ConstantOp>( 2533 loc, resType, rewriter.getZeroAttr(resType)); 2534 int64_t reductionDim = scanOp.reduction_dim(); 2535 bool inclusive = scanOp.inclusive(); 2536 int64_t destRank = destType.getRank(); 2537 VectorType initialValueType = scanOp.getInitialValueType(); 2538 int64_t initialValueRank = initialValueType.getRank(); 2539 2540 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2541 reductionShape[reductionDim] = 1; 2542 VectorType reductionType = VectorType::get(reductionShape, elType); 2543 SmallVector<int64_t> offsets(destRank, 0); 2544 SmallVector<int64_t> strides(destRank, 1); 2545 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2546 sizes[reductionDim] = 1; 2547 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2548 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2549 2550 Value lastOutput, lastInput; 2551 for (int i = 0; i < destShape[reductionDim]; i++) { 2552 offsets[reductionDim] = i; 2553 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2554 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2555 loc, reductionType, scanOp.source(), scanOffsets, scanSizes, 2556 scanStrides); 2557 Value output; 2558 if (i == 0) { 2559 if (inclusive) { 2560 output = input; 2561 } else { 2562 if (initialValueRank == 0) { 2563 // ShapeCastOp cannot handle 0-D vectors 2564 output = rewriter.create<vector::BroadcastOp>( 2565 loc, input.getType(), scanOp.initial_value()); 2566 } else { 2567 output = rewriter.create<vector::ShapeCastOp>( 2568 loc, input.getType(), scanOp.initial_value()); 2569 } 2570 } 2571 } else { 2572 Value y = inclusive ? 
input : lastInput; 2573 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2574 assert(output != nullptr); 2575 } 2576 result = rewriter.create<vector::InsertStridedSliceOp>( 2577 loc, output, result, offsets, strides); 2578 lastOutput = output; 2579 lastInput = input; 2580 } 2581 2582 Value reduction; 2583 if (initialValueRank == 0) { 2584 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2585 reduction = 2586 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2587 } else { 2588 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2589 lastOutput); 2590 } 2591 2592 rewriter.replaceOp(scanOp, {result, reduction}); 2593 return success(); 2594 } 2595 }; 2596 2597 } // namespace 2598 2599 void mlir::vector::populateVectorMaskMaterializationPatterns( 2600 RewritePatternSet &patterns, bool indexOptimizations) { 2601 patterns.add<VectorCreateMaskOpConversion, 2602 MaterializeTransferMask<vector::TransferReadOp>, 2603 MaterializeTransferMask<vector::TransferWriteOp>>( 2604 patterns.getContext(), indexOptimizations); 2605 } 2606 2607 void mlir::vector::populateShapeCastFoldingPatterns( 2608 RewritePatternSet &patterns) { 2609 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2610 } 2611 2612 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2613 RewritePatternSet &patterns) { 2614 patterns.add<BubbleDownVectorBitCastForExtract, 2615 BubbleDownBitCastForStridedSliceExtract, 2616 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2617 } 2618 2619 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2620 RewritePatternSet &patterns) { 2621 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2622 } 2623 2624 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2625 RewritePatternSet &patterns) { 2626 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2627 patterns.getContext()); 2628 } 2629 2630 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2631 RewritePatternSet &patterns) { 2632 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2633 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2634 patterns.getContext()); 2635 } 2636 2637 void mlir::vector::populateVectorContractLoweringPatterns( 2638 RewritePatternSet &patterns, VectorTransformsOptions options) { 2639 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2640 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2641 ContractionOpToOuterProductOpLowering>(options, 2642 patterns.getContext()); 2643 } 2644 2645 void mlir::vector::populateVectorTransposeLoweringPatterns( 2646 RewritePatternSet &patterns, VectorTransformsOptions options) { 2647 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2648 options, patterns.getContext()); 2649 } 2650 2651 void mlir::vector::populateVectorReductionToContractPatterns( 2652 RewritePatternSet &patterns) { 2653 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2654 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2655 ReorderCastOpsOnTranspose>(patterns.getContext()); 2656 } 2657 2658 void mlir::vector:: 2659 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2660 RewritePatternSet &patterns) { 2661 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2662 } 2663 2664 void mlir::vector::populateVectorTransferLoweringPatterns( 2665 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2666 patterns.add<TransferReadToVectorLoadLowering, 2667 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2668 maxTransferRank); 2669 patterns 2670 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2671 patterns.getContext()); 2672 } 2673 2674 void mlir::vector::populateVectorScanLoweringPatterns( 2675 RewritePatternSet &patterns) { 2676 patterns.add<ScanToArithOps>(patterns.getContext()); 2677 } 2678
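// Illustrative usage of the populate* entry points above (a sketch, assuming a
// pass that owns suitable `options`; not part of this file's API surface):
//
//   void runOnOperation() override {
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorContractLoweringPatterns(patterns, options);
//     vector::populateVectorTransferLoweringPatterns(
//         patterns, /*maxTransferRank=*/llvm::None);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }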