//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// Given a vector value `val` of type `type`, extracts position `pos` along
// dimension `index`, reducing the rank by one. An `index` of -1 leaves the
// value unmodified.
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// Given a value `val`, inserts it at position `pos` along dimension `index`
// of `result`, whose type is `type`. An `index` of -1 leaves `val` unmodified.
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//   The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//   Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.getSource().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All dimensions match: simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
                : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
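///
/// As a concrete sketch (hypothetical shapes, not taken from a test), a 2-D
/// mask such as:
///   %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// is rewritten roughly into:
///   %z = arith.constant dense<false> : vector<4x8xi1>
///   %l = vector.constant_mask [3] : vector<8xi1>
///   %0 = vector.insert %l, %z[0] : vector<8xi1> into vector<4x8xi1>
///   %x = vector.insert %l, %0[1] : vector<8xi1> into vector<4x8xi1>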
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
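///
/// For illustration only (hypothetical 2-D shapes), this turns:
///   %x = vector.create_mask %a, %b : vector<4x8xi1>
/// into, roughly:
///   %l = vector.create_mask %b : vector<8xi1>
///   %f = arith.constant dense<false> : vector<8xi1>
///   %c0 = arith.constant 0 : index
///   %cmp0 = arith.cmpi slt, %c0, %a : index
///   %sel0 = arith.select %cmp0, %l, %f : vector<8xi1>
///   %r0 = vector.insert %sel0, %zeroes[0] : vector<8xi1> into vector<4x8xi1>
///   .. repeated for rows 1 through 3 ..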
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
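///
/// For illustration only (hypothetical shapes), this turns:
///   %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
/// into, roughly:
///   %s0 = vector.extract_strided_slice %0
///       {offsets = [0], sizes = [3], strides = [1]}
///       : vector<6xf32> to vector<3xf32>
///   %d0 = vector.insert %s0, %zeroes[0] : vector<3xf32> into vector<2x3xf32>
///   %s1 = vector.extract_strided_slice %0
///       {offsets = [3], sizes = [3], strides = [1]}
///       : vector<6xf32> to vector<3xf32>
///   %1 = vector.insert %s1, %d0[1] : vector<3xf32> into vector<2x3xf32>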
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
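    // As a small illustration (hypothetical types), reshaping vector<2x2xf32>
    // to vector<4xf32> unrolls into:
    //   %e00 = vector.extract %src[0, 0] ; %r0 = vector.insert %e00, %zeroes[0]
    //   %e01 = vector.extract %src[0, 1] ; %r1 = vector.insert %e01, %r0[1]
    //   %e10 = vector.extract %src[1, 0] ; %r2 = vector.insert %e10, %r1[2]
    //   %e11 = vector.extract %src[1, 1] ; %r3 = vector.insert %e11, %r2[3]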
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vector as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders cast(transpose) to transpose(cast). This makes transpose ops and
/// contraction ops closer, which kicks in CombineContractTranspose pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
///   %1 = vector.transpose %0, [2, 0, 1]
///     : vector<32x16x8xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnTranspose
    : public OpInterfaceRewritePattern<CastOpInterface> {

  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
    if (!transpOp)
      return failure();

    auto castResTy = transpOp.getVectorType();
    castResTy = VectorType::get(castResTy.getShape(),
                                getElementTypeOrSelf(op->getResult(0)));
    auto castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        transpOp.getVector(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResult(0).getType(), castOp->getResult(0),
        transpOp.getTransp());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
///
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMaps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMaps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
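  // For instance (hypothetical shapes, for illustration only): an A(m, k) x
  // B(k, n) contraction on vector<4x8xf32> and vector<8x16xf32> is flattened
  // to vector<32xf32> and vector<128xf32>, multiplied with vector.matmul
  // (lhs_rows = 4, lhs_columns = 8, rhs_columns = 16), and the result is
  // shape_cast back to vector<4x16xf32> before being added to the accumulator.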
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMaps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {
  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor).
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap lhs and rhs, and transpose the new lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
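///
/// As a rough sketch (hypothetical shapes, not a verbatim test case), a
/// contraction with one free LHS dimension such as:
///   %x = vector.contract ... %a, %b, %c
///     : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
/// is unrolled along that free dimension into rank-reduced contractions:
///   %a0 = vector.extract %a[0] : vector<2x3xf32>
///   %c0 = vector.extract %c[0] : vector<2xf32>
///   %x0 = vector.contract ... %a0, %b, %c0
///     : vector<3xf32>, vector<3xf32> into f32
///   .. same for %a1/%c1 ..
/// and the partial results are reassembled with vector.insert.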
1522 // 1523 // TODO: break down into transpose/reshape/cast ops 1524 // when they become available to avoid code dup 1525 // TODO: investigate lowering order impact on performance 1526 LogicalResult 1527 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1528 PatternRewriter &rewriter) const { 1529 // TODO: implement masks. 1530 if (llvm::size(op.getMasks()) != 0) 1531 return failure(); 1532 1533 if (failed(filter(op))) 1534 return failure(); 1535 1536 // TODO: support mixed mode contract lowering. 1537 if (op.getLhsType().getElementType() != 1538 getElementTypeOrSelf(op.getAccType()) || 1539 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1540 return failure(); 1541 1542 // TODO: implement benefits, cost models. 1543 MLIRContext *ctx = op.getContext(); 1544 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1545 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1546 return success(); 1547 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1548 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1549 return success(); 1550 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1551 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1552 return success(); 1553 1554 // Find first batch dimension in LHS/RHS, and lower when found. 1555 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1556 if (!batchDimMap.empty()) { 1557 int64_t lhsIndex = batchDimMap[0].first; 1558 int64_t rhsIndex = batchDimMap[0].second; 1559 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1560 return success(); 1561 } 1562 1563 // Collect contracting dimensions. 1564 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1565 op.getContractingDimMap(); 1566 DenseSet<int64_t> lhsContractingDimSet; 1567 DenseSet<int64_t> rhsContractingDimSet; 1568 for (auto &dimPair : contractingDimMap) { 1569 lhsContractingDimSet.insert(dimPair.first); 1570 rhsContractingDimSet.insert(dimPair.second); 1571 } 1572 1573 // Find first free dimension in LHS, and lower when found. 1574 VectorType lhsType = op.getLhsType(); 1575 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1576 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1577 rewriter.replaceOp( 1578 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1579 return success(); 1580 } 1581 } 1582 1583 // Find first free dimension in RHS, and lower when found. 1584 VectorType rhsType = op.getRhsType(); 1585 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1586 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1587 rewriter.replaceOp( 1588 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1589 return success(); 1590 } 1591 } 1592 1593 // Lower the first remaining reduction dimension. 1594 if (!contractingDimMap.empty()) { 1595 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1596 return success(); 1597 } 1598 1599 return failure(); 1600 } 1601 1602 // Lower one parallel dimension. 1603 // TODO: consider reusing existing contract unrolling 1604 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1605 int64_t lhsIndex, int64_t rhsIndex, 1606 PatternRewriter &rewriter) const { 1607 VectorType lhsType = op.getLhsType(); 1608 VectorType rhsType = op.getRhsType(); 1609 VectorType resType = op.getResultType().cast<VectorType>(); 1610 // Find the iterator type index and result index. 
1611 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1612 int64_t iterIndex = -1; 1613 int64_t dimSize = -1; 1614 if (lhsIndex >= 0) { 1615 iterIndex = iMap[0].getDimPosition(lhsIndex); 1616 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1617 "parallel index should be free in LHS or batch in LHS/RHS"); 1618 dimSize = lhsType.getDimSize(lhsIndex); 1619 } else { 1620 assert(rhsIndex >= 0 && "missing parallel index"); 1621 iterIndex = iMap[1].getDimPosition(rhsIndex); 1622 dimSize = rhsType.getDimSize(rhsIndex); 1623 } 1624 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1625 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1626 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1627 int64_t resIndex = lookup.getValue(); 1628 // Construct new iterator types and affine map array attribute. 1629 std::array<AffineMap, 3> lowIndexingMaps = { 1630 adjustMap(iMap[0], iterIndex, rewriter), 1631 adjustMap(iMap[1], iterIndex, rewriter), 1632 adjustMap(iMap[2], iterIndex, rewriter)}; 1633 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1634 auto lowIter = 1635 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1636 // Unroll into a series of lower dimensional vector.contract ops. 1637 Location loc = op.getLoc(); 1638 Value result = rewriter.create<arith::ConstantOp>( 1639 loc, resType, rewriter.getZeroAttr(resType)); 1640 for (int64_t d = 0; d < dimSize; ++d) { 1641 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1642 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1643 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1644 Value lowContract = rewriter.create<vector::ContractionOp>( 1645 loc, lhs, rhs, acc, lowAffine, lowIter); 1646 result = 1647 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1648 } 1649 return result; 1650 } 1651 1652 // Lower one reduction dimension. 1653 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1654 PatternRewriter &rewriter) const { 1655 auto loc = op.getLoc(); 1656 VectorType lhsType = op.getLhsType(); 1657 VectorType rhsType = op.getRhsType(); 1658 Type resType = op.getResultType(); 1659 assert(!resType.isa<VectorType>()); 1660 bool isInt = resType.isa<IntegerType>(); 1661 // Use iterator index 0. 1662 int64_t iterIndex = 0; 1663 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1664 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1665 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1666 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1667 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1668 int64_t lhsIndex = lookupLhs.getValue(); 1669 int64_t rhsIndex = lookupRhs.getValue(); 1670 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1671 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1672 // Base case. 1673 if (lhsType.getRank() == 1) { 1674 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1675 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 1676 auto kind = vector::CombiningKind::ADD; 1677 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1678 if (auto acc = op.getAcc()) 1679 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1680 return res; 1681 } 1682 // Construct new iterator types and affine map array attribute. 
1683 std::array<AffineMap, 3> lowIndexingMaps = { 1684 adjustMap(iMap[0], iterIndex, rewriter), 1685 adjustMap(iMap[1], iterIndex, rewriter), 1686 adjustMap(iMap[2], iterIndex, rewriter)}; 1687 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1688 auto lowIter = 1689 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1690 // Unroll into a series of lower dimensional vector.contract ops. 1691 // By feeding the initial accumulator into the first contraction, 1692 // and the result of each contraction into the next, eventually 1693 // the sum of all reductions is computed. 1694 Value result = op.getAcc(); 1695 for (int64_t d = 0; d < dimSize; ++d) { 1696 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1697 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1698 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1699 lowAffine, lowIter); 1700 } 1701 return result; 1702 } 1703 1704 } // namespace mlir 1705 1706 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1707 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1708 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1709 OpBuilder::InsertionGuard guard(builder); 1710 builder.setInsertionPointAfter(op); 1711 Location loc = op->getLoc(); 1712 if (op->getNumResults() != 1) 1713 return {}; 1714 Value result = op->getResult(0); 1715 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1716 if (!type || map.getNumResults() != multiplicity.size()) 1717 return {}; 1718 // For each dimension being distributed check that the size is a multiple of 1719 // the multiplicity. To handle more sizes we would need to support masking. 1720 unsigned multiplictyCount = 0; 1721 for (auto exp : map.getResults()) { 1722 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1723 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1724 type.getDimSize(affinExp.getPosition()) % 1725 multiplicity[multiplictyCount++] != 1726 0) 1727 return {}; 1728 } 1729 DistributeOps ops; 1730 ops.extract = 1731 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1732 ops.insert = 1733 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1734 return ops; 1735 } 1736 1737 /// Progressive lowering of transfer_read. This pattern supports lowering of 1738 /// `vector.transfer_read` to a combination of `vector.load` and 1739 /// `vector.broadcast` if all of the following hold: 1740 /// - Stride of most minor memref dimension must be 1. 1741 /// - Out-of-bounds masking is not required. 1742 /// - If the memref's element type is a vector type then it coincides with the 1743 /// result type. 1744 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1745 struct TransferReadToVectorLoadLowering 1746 : public OpRewritePattern<vector::TransferReadOp> { 1747 TransferReadToVectorLoadLowering(MLIRContext *context, 1748 llvm::Optional<unsigned> maxRank) 1749 : OpRewritePattern<vector::TransferReadOp>(context), 1750 maxTransferRank(maxRank) {} 1751 1752 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1753 PatternRewriter &rewriter) const override { 1754 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1755 return failure(); 1756 1757 SmallVector<unsigned, 4> broadcastedDims; 1758 // Permutations are handled by VectorToSCF or 1759 // populateVectorTransferPermutationMapLoweringPatterns. 
1760 // We let the 0-d corner case pass-through as it is supported. 1761 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( 1762 &broadcastedDims)) 1763 return failure(); 1764 1765 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1766 if (!memRefType) 1767 return failure(); 1768 1769 // Non-unit strides are handled by VectorToSCF. 1770 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1771 return failure(); 1772 1773 // If there is broadcasting involved then we first load the unbroadcasted 1774 // vector, and then broadcast it with `vector.broadcast`. 1775 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1776 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1777 vectorShape.end()); 1778 for (unsigned i : broadcastedDims) 1779 unbroadcastedVectorShape[i] = 1; 1780 VectorType unbroadcastedVectorType = VectorType::get( 1781 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1782 1783 // `vector.load` supports vector types as memref's elements only when the 1784 // resulting vector type is the same as the element type. 1785 auto memrefElTy = memRefType.getElementType(); 1786 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1787 return failure(); 1788 1789 // Otherwise, element types of the memref and the vector must match. 1790 if (!memrefElTy.isa<VectorType>() && 1791 memrefElTy != read.getVectorType().getElementType()) 1792 return failure(); 1793 1794 // Out-of-bounds dims are handled by MaterializeTransferMask. 1795 if (read.hasOutOfBoundsDim()) 1796 return failure(); 1797 1798 // Create vector load op. 1799 Operation *loadOp; 1800 if (read.getMask()) { 1801 Value fill = rewriter.create<vector::SplatOp>( 1802 read.getLoc(), unbroadcastedVectorType, read.getPadding()); 1803 loadOp = rewriter.create<vector::MaskedLoadOp>( 1804 read.getLoc(), unbroadcastedVectorType, read.getSource(), 1805 read.getIndices(), read.getMask(), fill); 1806 } else { 1807 loadOp = rewriter.create<vector::LoadOp>( 1808 read.getLoc(), unbroadcastedVectorType, read.getSource(), 1809 read.getIndices()); 1810 } 1811 1812 // Insert a broadcasting op if required. 1813 if (!broadcastedDims.empty()) { 1814 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1815 read, read.getVectorType(), loadOp->getResult(0)); 1816 } else { 1817 rewriter.replaceOp(read, loadOp->getResult(0)); 1818 } 1819 1820 return success(); 1821 } 1822 1823 llvm::Optional<unsigned> maxTransferRank; 1824 }; 1825 1826 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1827 // TODO: we shouldn't cross the vector/scalar domains just for this 1828 // but atm we lack the infra to avoid it. Possible solutions include: 1829 // - go directly to LLVM + bitcast 1830 // - introduce a bitcast op and likely a new pointer dialect 1831 // - let memref.load/store additionally support the 0-d vector case 1832 // There are still deeper data layout issues lingering even in this 1833 // trivial case (for architectures for which this matters). 
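//
// Note that the pattern below actually matches any single-element vector
// (`getNumElements() == 1`), not just the 0-d case. A minimal sketch of the
// rewrite, with types chosen purely for illustration:
//   %0 = vector.load %base[%i] : memref<16xf32>, vector<1xf32>
// becomes:
//   %s = memref.load %base[%i] : memref<16xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>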
1834 struct VectorLoadToMemrefLoadLowering
1835     : public OpRewritePattern<vector::LoadOp> {
1836   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
1837
1838   LogicalResult matchAndRewrite(vector::LoadOp loadOp,
1839                                 PatternRewriter &rewriter) const override {
1840     auto vecType = loadOp.getVectorType();
1841     if (vecType.getNumElements() != 1)
1842       return failure();
1843     auto memrefLoad = rewriter.create<memref::LoadOp>(
1844         loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
1845     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
1846                                                      memrefLoad);
1847     return success();
1848   }
1849 };
1850
1851 /// Replace a 0-d vector.store with a vector.extractelement + memref.store.
1852 struct VectorStoreToMemrefStoreLowering
1853     : public OpRewritePattern<vector::StoreOp> {
1854   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
1855
1856   LogicalResult matchAndRewrite(vector::StoreOp storeOp,
1857                                 PatternRewriter &rewriter) const override {
1858     auto vecType = storeOp.getVectorType();
1859     if (vecType.getNumElements() != 1)
1860       return failure();
1861     Value extracted;
1862     if (vecType.getRank() == 0) {
1863       // TODO: Unify once ExtractOp supports 0-d vectors.
1864       extracted = rewriter.create<vector::ExtractElementOp>(
1865           storeOp.getLoc(), storeOp.getValueToStore());
1866     } else {
1867       SmallVector<int64_t> indices(vecType.getRank(), 0);
1868       extracted = rewriter.create<vector::ExtractOp>(
1869           storeOp.getLoc(), storeOp.getValueToStore(), indices);
1870     }
1871
1872     rewriter.replaceOpWithNewOp<memref::StoreOp>(
1873         storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
1874     return success();
1875   }
1876 };
1877
1878 /// Progressive lowering of transfer_write. This pattern supports lowering of
1879 /// `vector.transfer_write` to `vector.store` if all of the following hold:
1880 /// - Stride of most minor memref dimension must be 1.
1881 /// - Out-of-bounds masking is not required.
1882 /// - If the memref's element type is a vector type then it coincides with the
1883 ///   type of the written value.
1884 /// - The permutation map is the minor identity map (neither permutation nor
1885 ///   broadcasting is allowed).
1886 struct TransferWriteToVectorStoreLowering
1887     : public OpRewritePattern<vector::TransferWriteOp> {
1888   TransferWriteToVectorStoreLowering(MLIRContext *context,
1889                                      llvm::Optional<unsigned> maxRank)
1890       : OpRewritePattern<vector::TransferWriteOp>(context),
1891         maxTransferRank(maxRank) {}
1892
1893   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
1894                                 PatternRewriter &rewriter) const override {
1895     if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
1896       return failure();
1897
1898     // Permutations are handled by VectorToSCF or
1899     // populateVectorTransferPermutationMapLoweringPatterns.
1900     if ( // pass-through for the 0-d corner case.
1901         !write.getPermutationMap().isMinorIdentity())
1902       return failure();
1903
1904     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
1905     if (!memRefType)
1906       return failure();
1907
1908     // Non-unit strides are handled by VectorToSCF.
1909     if (!vector::isLastMemrefDimUnitStride(memRefType))
1910       return failure();
1911
1912     // `vector.store` supports vector types as memref's elements only when the
1913     // type of the vector value being written is the same as the element type.
1914     auto memrefElTy = memRefType.getElementType();
1915     if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
1916       return failure();
1917
1918     // Otherwise, element types of the memref and the vector must match.
1919     if (!memrefElTy.isa<VectorType>() &&
1920         memrefElTy != write.getVectorType().getElementType())
1921       return failure();
1922
1923     // Out-of-bounds dims are handled by MaterializeTransferMask.
1924     if (write.hasOutOfBoundsDim())
1925       return failure();
1926     if (write.getMask()) {
1927       rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
1928           write, write.getSource(), write.getIndices(), write.getMask(),
1929           write.getVector());
1930     } else {
1931       rewriter.replaceOpWithNewOp<vector::StoreOp>(
1932           write, write.getVector(), write.getSource(), write.getIndices());
1933     }
1934     return success();
1935   }
1936
1937   llvm::Optional<unsigned> maxTransferRank;
1938 };
1939
1940 // Returns the values in `arrayAttr` as an integer vector.
1941 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
1942   return llvm::to_vector<4>(
1943       llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
1944                       [](IntegerAttr attr) { return attr.getInt(); }));
1945 }
1946
1947 // Shuffles vector.bitcast op after vector.extract op.
1948 //
1949 // This transforms IR like:
1950 //   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
1951 //   %1 = vector.extract %0[3] : vector<8xf16>
1952 // Into:
1953 //   %0 = vector.extract %src[1] : vector<4xf32>
1954 //   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
1955 //   %2 = vector.extract %1[1] : vector<2xf16>
1956 struct BubbleDownVectorBitCastForExtract
1957     : public OpRewritePattern<vector::ExtractOp> {
1958   using OpRewritePattern::OpRewritePattern;
1959
1960   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
1961                                 PatternRewriter &rewriter) const override {
1962     // Only support extracting scalars for now.
1963     if (extractOp.getVectorType().getRank() != 1)
1964       return failure();
1965
1966     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
1967     if (!castOp)
1968       return failure();
1969
1970     VectorType castSrcType = castOp.getSourceVectorType();
1971     VectorType castDstType = castOp.getResultVectorType();
1972     assert(castSrcType.getRank() == castDstType.getRank());
1973
1974     // Fail to match if we only have one element in the cast op source.
1975     // This is to avoid an infinite loop, given that this pattern can
1976     // generate such cases.
1977     if (castSrcType.getNumElements() == 1)
1978       return failure();
1979
1980     // Only support casting to a larger number of elements for now.
1981     // E.g., vector<4xf32> -> vector<8xf16>.
1982     if (castSrcType.getNumElements() > castDstType.getNumElements())
1983       return failure();
1984
1985     unsigned expandRatio =
1986         castDstType.getNumElements() / castSrcType.getNumElements();
1987
1988     auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
1989       return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
1990     };
1991
1992     uint64_t index = getFirstIntValue(extractOp.getPosition());
1993
1994     // Get the single scalar (as a vector) in the source value that packs the
1995     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
1996     VectorType oneScalarType =
1997         VectorType::get({1}, castSrcType.getElementType());
1998     Value packedValue = rewriter.create<vector::ExtractOp>(
1999         extractOp.getLoc(), oneScalarType, castOp.getSource(),
2000         rewriter.getI64ArrayAttr(index / expandRatio));
2001
2002     // Cast it to a vector with the desired scalar's type.
2003     // E.g. vector<1xf32> -> vector<2xf16>.
2004     VectorType packedType =
2005         VectorType::get({expandRatio}, castDstType.getElementType());
2006     Value castedValue = rewriter.create<vector::BitCastOp>(
2007         extractOp.getLoc(), packedType, packedValue);
2008
2009     // Finally extract the desired scalar.
2010     rewriter.replaceOpWithNewOp<vector::ExtractOp>(
2011         extractOp, extractOp.getType(), castedValue,
2012         rewriter.getI64ArrayAttr(index % expandRatio));
2013
2014     return success();
2015   }
2016 };
2017
2018 // Shuffles vector.bitcast op after vector.extract_strided_slice op.
2019 //
2020 // This transforms IR like:
2021 //   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
2022 //   %0 = vector.extract_strided_slice %cast {
2023 //          offsets = [4], sizes = [4], strides = [1]
2024 //        } : vector<8xf16> to vector<4xf16>
2025 // Into:
2026 //   %0 = vector.extract_strided_slice %arg0 {
2027 //          offsets = [2], sizes = [2], strides = [1]
2028 //        } : vector<4xf32> to vector<2xf32>
2029 //   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
2030 struct BubbleDownBitCastForStridedSliceExtract
2031     : public OpRewritePattern<vector::ExtractStridedSliceOp> {
2032   using OpRewritePattern::OpRewritePattern;
2033
2034   LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
2035                                 PatternRewriter &rewriter) const override {
2036     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
2037     if (!castOp)
2038       return failure();
2039
2040     VectorType castSrcType = castOp.getSourceVectorType();
2041     VectorType castDstType = castOp.getResultVectorType();
2042     assert(castSrcType.getRank() == castDstType.getRank());
2043
2044     int64_t castSrcLastDim = castSrcType.getShape().back();
2045     int64_t castDstLastDim = castDstType.getShape().back();
2046     // Require casting to more elements for now; other cases to be implemented.
2047     if (castSrcLastDim > castDstLastDim)
2048       return failure();
2049
2050     // Only accept strides of all ones for now.
2051     if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
2052                      [](const APInt &val) { return !val.isOneValue(); }))
2053       return failure();
2054
2055     unsigned rank = extractOp.getVectorType().getRank();
2056     assert(castDstLastDim % castSrcLastDim == 0);
2057     int64_t expandRatio = castDstLastDim / castSrcLastDim;
2058
2059     // If we have fewer offsets than the rank, then we are implicitly selecting
2060     // the full range for the last bitcasted dimension; other dimensions are
2061     // not affected. Otherwise, we need to scale down the last dimension's
2062     // offset given that we are now extracting from fewer elements.
2063     ArrayAttr newOffsets = extractOp.getOffsets();
2064     if (newOffsets.size() == rank) {
2065       SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
2066       if (offsets.back() % expandRatio != 0)
2067         return failure();
2068       offsets.back() = offsets.back() / expandRatio;
2069       newOffsets = rewriter.getI64ArrayAttr(offsets);
2070     }
2071
2072     // Similarly for sizes.
2073     ArrayAttr newSizes = extractOp.getSizes();
2074     if (newSizes.size() == rank) {
2075       SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
2076       if (sizes.back() % expandRatio != 0)
2077         return failure();
2078       sizes.back() = sizes.back() / expandRatio;
2079       newSizes = rewriter.getI64ArrayAttr(sizes);
2080     }
2081
2082     SmallVector<int64_t, 4> dims =
2083         llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
2084     dims.back() = dims.back() / expandRatio;
2085     VectorType newExtractType =
2086         VectorType::get(dims, castSrcType.getElementType());
2087
2088     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
2089         extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets,
2090         newSizes, extractOp.getStrides());
2091
2092     rewriter.replaceOpWithNewOp<vector::BitCastOp>(
2093         extractOp, extractOp.getType(), newExtractOp);
2094
2095     return success();
2096   }
2097 };
2098
2099 // Shuffles vector.bitcast op before vector.insert_strided_slice op.
2100 //
2101 // This transforms IR like:
2102 //   %0 = vector.insert_strided_slice %src, %dst {
2103 //          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
2104 //   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
2105 // Into:
2106 //   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
2107 //   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
2108 //   %2 = vector.insert_strided_slice %0, %1 {
2109 //          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
2110 struct BubbleUpBitCastForStridedSliceInsert
2111     : public OpRewritePattern<vector::BitCastOp> {
2112   using OpRewritePattern::OpRewritePattern;
2113   LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
2114                                 PatternRewriter &rewriter) const override {
2115     VectorType castSrcType = bitcastOp.getSourceVectorType();
2116     VectorType castDstType = bitcastOp.getResultVectorType();
2117     assert(castSrcType.getRank() == castDstType.getRank());
2118
2119     int64_t castSrcLastDim = castSrcType.getShape().back();
2120     int64_t castDstLastDim = castDstType.getShape().back();
2121     // Require casting to fewer elements for now; other cases to be implemented.
2122     if (castSrcLastDim < castDstLastDim)
2123       return failure();
2124
2125     assert(castSrcLastDim % castDstLastDim == 0);
2126     int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
2127
2128     auto insertOp =
2129         bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
2130     if (!insertOp)
2131       return failure();
2132
2133     // Only accept strides of all ones for now.
2134     if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
2135                      [](const APInt &val) { return !val.isOneValue(); }))
2136       return failure();
2137
2138     unsigned rank = insertOp.getSourceVectorType().getRank();
2139     // Require insert op to have the same rank for the source and destination
2140     // vector; other cases to be implemented.
2141 if (rank != insertOp.getDestVectorType().getRank()) 2142 return failure(); 2143 2144 ArrayAttr newOffsets = insertOp.getOffsets(); 2145 assert(newOffsets.size() == rank); 2146 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2147 if (offsets.back() % shrinkRatio != 0) 2148 return failure(); 2149 offsets.back() = offsets.back() / shrinkRatio; 2150 newOffsets = rewriter.getI64ArrayAttr(offsets); 2151 2152 SmallVector<int64_t, 4> srcDims = 2153 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2154 srcDims.back() = srcDims.back() / shrinkRatio; 2155 VectorType newCastSrcType = 2156 VectorType::get(srcDims, castDstType.getElementType()); 2157 2158 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2159 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 2160 2161 SmallVector<int64_t, 4> dstDims = 2162 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2163 dstDims.back() = dstDims.back() / shrinkRatio; 2164 VectorType newCastDstType = 2165 VectorType::get(dstDims, castDstType.getElementType()); 2166 2167 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2168 bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 2169 2170 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2171 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2172 insertOp.getStrides()); 2173 2174 return success(); 2175 } 2176 }; 2177 2178 // Helper that returns a vector comparison that constructs a mask: 2179 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2180 // 2181 // If `dim == 0` then the result will be a 0-D vector. 2182 // 2183 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2184 // much more compact, IR for this operation, but LLVM eventually 2185 // generates more elaborate instructions for this intrinsic since it 2186 // is very conservative on the boundary conditions. 2187 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2188 bool force32BitVectorIndices, int64_t dim, 2189 Value b, Value *off = nullptr) { 2190 auto loc = op->getLoc(); 2191 // If we can assume all indices fit in 32-bit, we perform the vector 2192 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2193 // Otherwise we perform the vector comparison using 64-bit indices. 2194 Type idxType = 2195 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 2196 DenseIntElementsAttr indicesAttr; 2197 if (dim == 0 && force32BitVectorIndices) { 2198 indicesAttr = DenseIntElementsAttr::get( 2199 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2200 } else if (dim == 0) { 2201 indicesAttr = DenseIntElementsAttr::get( 2202 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2203 } else if (force32BitVectorIndices) { 2204 indicesAttr = rewriter.getI32VectorAttr( 2205 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2206 } else { 2207 indicesAttr = rewriter.getI64VectorAttr( 2208 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2209 } 2210 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2211 // Add in an offset if requested. 2212 if (off) { 2213 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2214 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2215 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2216 } 2217 // Construct the vector comparison. 
2218 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2219 Value bounds = 2220 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2221 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2222 bounds); 2223 } 2224 2225 template <typename ConcreteOp> 2226 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2227 public: 2228 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2229 : mlir::OpRewritePattern<ConcreteOp>(context), 2230 force32BitVectorIndices(enableIndexOpt) {} 2231 2232 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2233 PatternRewriter &rewriter) const override { 2234 if (!xferOp.hasOutOfBoundsDim()) 2235 return failure(); 2236 2237 if (xferOp.getVectorType().getRank() > 1 || 2238 llvm::size(xferOp.getIndices()) == 0) 2239 return failure(); 2240 2241 Location loc = xferOp->getLoc(); 2242 VectorType vtp = xferOp.getVectorType(); 2243 2244 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2245 // set and [dim - offset .. vector_length) unset. 2246 // 2247 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2248 // dimensions here. 2249 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 2250 Value off = xferOp.getIndices()[lastIndex]; 2251 Value dim = 2252 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 2253 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2254 Value mask = rewriter.create<vector::CreateMaskOp>( 2255 loc, 2256 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2257 vtp.getNumScalableDims()), 2258 b); 2259 if (xferOp.getMask()) { 2260 // Intersect the in-bounds with the mask specified as an op parameter. 2261 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 2262 } 2263 2264 rewriter.updateRootInPlace(xferOp, [&]() { 2265 xferOp.getMaskMutable().assign(mask); 2266 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 2267 }); 2268 2269 return success(); 2270 } 2271 2272 private: 2273 const bool force32BitVectorIndices; 2274 }; 2275 2276 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2277 class VectorCreateMaskOpConversion 2278 : public OpRewritePattern<vector::CreateMaskOp> { 2279 public: 2280 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2281 bool enableIndexOpt) 2282 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2283 force32BitVectorIndices(enableIndexOpt) {} 2284 2285 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2286 PatternRewriter &rewriter) const override { 2287 auto dstType = op.getType(); 2288 if (dstType.cast<VectorType>().isScalable()) 2289 return failure(); 2290 int64_t rank = dstType.getRank(); 2291 if (rank > 1) 2292 return failure(); 2293 rewriter.replaceOp( 2294 op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 2295 rank == 0 ? 0 : dstType.getDimSize(0), 2296 op.getOperand(0))); 2297 return success(); 2298 } 2299 2300 private: 2301 const bool force32BitVectorIndices; 2302 }; 2303 2304 // Drop inner most contiguous unit dimensions from transfer_read operand. 2305 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2306 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2307 2308 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2309 PatternRewriter &rewriter) const override { 2310 // TODO: support 0-d corner case. 
2311 if (readOp.getTransferRank() == 0) 2312 return failure(); 2313 2314 // TODO: support mask. 2315 if (readOp.getMask()) 2316 return failure(); 2317 2318 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>(); 2319 if (!srcType || !srcType.hasStaticShape()) 2320 return failure(); 2321 2322 if (!readOp.getPermutationMap().isMinorIdentity()) 2323 return failure(); 2324 2325 auto targetType = readOp.getVectorType(); 2326 if (targetType.getRank() <= 1) 2327 return failure(); 2328 2329 SmallVector<int64_t> srcStrides; 2330 int64_t srcOffset; 2331 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2332 return failure(); 2333 2334 size_t dimsToDrop = 0; 2335 for (size_t i = 1; i < srcStrides.size(); ++i) { 2336 int dim = srcType.getRank() - i - 1; 2337 if (srcStrides[dim] == 1) { 2338 dimsToDrop++; 2339 } else { 2340 break; 2341 } 2342 } 2343 if (dimsToDrop == 0) 2344 return failure(); 2345 2346 auto resultTargetVecType = 2347 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2348 targetType.getElementType()); 2349 2350 MemRefType resultMemrefType; 2351 if (srcType.getLayout().getAffineMap().isIdentity()) { 2352 resultMemrefType = MemRefType::get( 2353 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2354 {}, srcType.getMemorySpaceAsInt()); 2355 } else { 2356 AffineMap map = srcType.getLayout().getAffineMap(); 2357 int numResultDims = map.getNumDims() - dimsToDrop; 2358 int numSymbols = map.getNumSymbols(); 2359 for (size_t i = 0; i < dimsToDrop; ++i) { 2360 int dim = srcType.getRank() - i - 1; 2361 map = map.replace(rewriter.getAffineDimExpr(dim), 2362 rewriter.getAffineConstantExpr(0), numResultDims, 2363 numSymbols); 2364 } 2365 resultMemrefType = MemRefType::get( 2366 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2367 map, srcType.getMemorySpaceAsInt()); 2368 } 2369 2370 auto loc = readOp.getLoc(); 2371 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2372 SmallVector<int64_t> strides(srcType.getRank(), 1); 2373 2374 ArrayAttr inBoundsAttr = 2375 readOp.getInBounds() 2376 ? rewriter.getArrayAttr( 2377 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) 2378 : ArrayAttr(); 2379 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2380 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), 2381 strides); 2382 auto permMap = getTransferMinorIdentityMap( 2383 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2384 Value result = rewriter.create<vector::TransferReadOp>( 2385 loc, resultTargetVecType, rankedReducedView, 2386 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2387 readOp.getPadding(), 2388 // TODO: support mask. 2389 /*mask=*/Value(), inBoundsAttr); 2390 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2391 result); 2392 return success(); 2393 } 2394 }; 2395 2396 namespace { 2397 2398 /// This function checks to see if the vector combining kind 2399 /// is consistent with the integer or float element type. 
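/// For example, isValidKind(/*isInt=*/true, CombiningKind::ADD) is true,
/// while isValidKind(/*isInt=*/true, CombiningKind::MINF) is false.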
2400 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2401 using vector::CombiningKind; 2402 enum class KindType { FLOAT, INT, INVALID }; 2403 KindType type{KindType::INVALID}; 2404 switch (kind) { 2405 case CombiningKind::MINF: 2406 case CombiningKind::MAXF: 2407 type = KindType::FLOAT; 2408 break; 2409 case CombiningKind::MINUI: 2410 case CombiningKind::MINSI: 2411 case CombiningKind::MAXUI: 2412 case CombiningKind::MAXSI: 2413 case CombiningKind::AND: 2414 case CombiningKind::OR: 2415 case CombiningKind::XOR: 2416 type = KindType::INT; 2417 break; 2418 case CombiningKind::ADD: 2419 case CombiningKind::MUL: 2420 type = isInt ? KindType::INT : KindType::FLOAT; 2421 break; 2422 } 2423 bool isValidIntKind = (type == KindType::INT) && isInt; 2424 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2425 return (isValidIntKind || isValidFloatKind); 2426 } 2427 2428 /// This function constructs the appropriate integer or float 2429 /// operation given the vector combining kind and operands. The 2430 /// supported int operations are : add, mul, min (signed/unsigned), 2431 /// max(signed/unsigned), and, or, xor. The supported float 2432 /// operations are : add, mul, min and max. 2433 static Value genOperator(Location loc, Value x, Value y, 2434 vector::CombiningKind kind, 2435 PatternRewriter &rewriter) { 2436 using vector::CombiningKind; 2437 2438 auto elType = x.getType().cast<VectorType>().getElementType(); 2439 bool isInt = elType.isIntOrIndex(); 2440 2441 Value combinedResult{nullptr}; 2442 switch (kind) { 2443 case CombiningKind::ADD: 2444 if (isInt) 2445 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2446 else 2447 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2448 break; 2449 case CombiningKind::MUL: 2450 if (isInt) 2451 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2452 else 2453 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2454 break; 2455 case CombiningKind::MINUI: 2456 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2457 break; 2458 case CombiningKind::MINSI: 2459 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2460 break; 2461 case CombiningKind::MAXUI: 2462 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2463 break; 2464 case CombiningKind::MAXSI: 2465 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2466 break; 2467 case CombiningKind::AND: 2468 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2469 break; 2470 case CombiningKind::OR: 2471 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2472 break; 2473 case CombiningKind::XOR: 2474 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2475 break; 2476 case CombiningKind::MINF: 2477 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2478 break; 2479 case CombiningKind::MAXF: 2480 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2481 break; 2482 } 2483 return combinedResult; 2484 } 2485 2486 /// Convert vector.scan op into arith ops and 2487 /// vector.insert_strided_slice/extract_strided_slice 2488 /// 2489 /// Ex: 2490 /// ``` 2491 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2492 /// 1} : 2493 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2494 /// ``` 2495 /// Gets converted to: 2496 /// ``` 2497 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2498 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2499 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32>
/// %1 = vector.insert_strided_slice %0, %cst {offsets = [0, 0],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %2 = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [2, 1],
///      strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
/// %3 = arith.addi %0, %2 : vector<2x1xi32>
/// %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %5 = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
///      strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
/// %6 = arith.addi %3, %5 : vector<2x1xi32>
/// %7 = vector.insert_strided_slice %6, %4 {offsets = [0, 2],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
2512 /// ```
2513 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
2514   using OpRewritePattern<vector::ScanOp>::OpRewritePattern;
2515
2516   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
2517                                 PatternRewriter &rewriter) const override {
2518     auto loc = scanOp.getLoc();
2519     VectorType destType = scanOp.getDestType();
2520     ArrayRef<int64_t> destShape = destType.getShape();
2521     auto elType = destType.getElementType();
2522     bool isInt = elType.isIntOrIndex();
2523     if (!isValidKind(isInt, scanOp.getKind()))
2524       return failure();
2525
2526     VectorType resType = VectorType::get(destShape, elType);
2527     Value result = rewriter.create<arith::ConstantOp>(
2528         loc, resType, rewriter.getZeroAttr(resType));
2529     int64_t reductionDim = scanOp.getReductionDim();
2530     bool inclusive = scanOp.getInclusive();
2531     int64_t destRank = destType.getRank();
2532     VectorType initialValueType = scanOp.getInitialValueType();
2533     int64_t initialValueRank = initialValueType.getRank();
2534
2535     SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
2536     reductionShape[reductionDim] = 1;
2537     VectorType reductionType = VectorType::get(reductionShape, elType);
2538     SmallVector<int64_t> offsets(destRank, 0);
2539     SmallVector<int64_t> strides(destRank, 1);
2540     SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
2541     sizes[reductionDim] = 1;
2542     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
2543     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
2544
2545     Value lastOutput, lastInput;
2546     for (int i = 0; i < destShape[reductionDim]; i++) {
2547       offsets[reductionDim] = i;
2548       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
2549       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
2550           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
2551           scanStrides);
2552       Value output;
2553       if (i == 0) {
2554         if (inclusive) {
2555           output = input;
2556         } else {
2557           if (initialValueRank == 0) {
2558             // ShapeCastOp cannot handle 0-D vectors.
2559             output = rewriter.create<vector::BroadcastOp>(
2560                 loc, input.getType(), scanOp.getInitialValue());
2561           } else {
2562             output = rewriter.create<vector::ShapeCastOp>(
2563                 loc, input.getType(), scanOp.getInitialValue());
2564           }
2565         }
2566       } else {
2567         Value y = inclusive ?
input : lastInput; 2568 output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); 2569 assert(output != nullptr); 2570 } 2571 result = rewriter.create<vector::InsertStridedSliceOp>( 2572 loc, output, result, offsets, strides); 2573 lastOutput = output; 2574 lastInput = input; 2575 } 2576 2577 Value reduction; 2578 if (initialValueRank == 0) { 2579 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2580 reduction = 2581 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2582 } else { 2583 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2584 lastOutput); 2585 } 2586 2587 rewriter.replaceOp(scanOp, {result, reduction}); 2588 return success(); 2589 } 2590 }; 2591 2592 } // namespace 2593 2594 void mlir::vector::populateVectorMaskMaterializationPatterns( 2595 RewritePatternSet &patterns, bool force32BitVectorIndices) { 2596 patterns.add<VectorCreateMaskOpConversion, 2597 MaterializeTransferMask<vector::TransferReadOp>, 2598 MaterializeTransferMask<vector::TransferWriteOp>>( 2599 patterns.getContext(), force32BitVectorIndices); 2600 } 2601 2602 void mlir::vector::populateShapeCastFoldingPatterns( 2603 RewritePatternSet &patterns) { 2604 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2605 } 2606 2607 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2608 RewritePatternSet &patterns) { 2609 patterns.add<BubbleDownVectorBitCastForExtract, 2610 BubbleDownBitCastForStridedSliceExtract, 2611 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2612 } 2613 2614 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2615 RewritePatternSet &patterns) { 2616 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2617 } 2618 2619 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2620 RewritePatternSet &patterns) { 2621 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2622 patterns.getContext()); 2623 } 2624 2625 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2626 RewritePatternSet &patterns) { 2627 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2628 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2629 patterns.getContext()); 2630 } 2631 2632 void mlir::vector::populateVectorContractLoweringPatterns( 2633 RewritePatternSet &patterns, VectorTransformsOptions options) { 2634 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2635 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2636 ContractionOpToOuterProductOpLowering>(options, 2637 patterns.getContext()); 2638 } 2639 2640 void mlir::vector::populateVectorTransposeLoweringPatterns( 2641 RewritePatternSet &patterns, VectorTransformsOptions options) { 2642 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2643 options, patterns.getContext()); 2644 } 2645 2646 void mlir::vector::populateVectorReductionToContractPatterns( 2647 RewritePatternSet &patterns) { 2648 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2649 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2650 ReorderCastOpsOnTranspose>(patterns.getContext()); 2651 } 2652 2653 void mlir::vector:: 2654 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2655 RewritePatternSet &patterns) { 2656 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2657 } 2658 2659 void mlir::vector::populateVectorTransferLoweringPatterns( 2660 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2661 patterns.add<TransferReadToVectorLoadLowering, 2662 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2663 maxTransferRank); 2664 patterns 2665 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2666 patterns.getContext()); 2667 } 2668 2669 void mlir::vector::populateVectorScanLoweringPatterns( 2670 RewritePatternSet &patterns) { 2671 patterns.add<ScanToArithOps>(patterns.getContext()); 2672 } 2673
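// A typical way to consume the populate* entry points above (a sketch only,
// not tied to any particular pass in this file): gather the patterns into a
// RewritePatternSet and apply them greedily, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/llvm::None);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));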