//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
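  // Note: the loop below extracts each slice along the leading dimension,
  // recurses with the target dimension shifted down by one, and inserts the
  // reduced slices into a zero-initialized vector of the result type.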
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y] : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.vector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
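    // Note: the flat path below shape_casts the 2-D vector to 1-D, emits a
    // vector.flat_transpose (which is designed to map onto LLVM's matrix
    // intrinsics), and shape_casts the result back to 2-D.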
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ?
              genMultI(loc, op.lhs(), b, acc, kind, rewriter)
              : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
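/// For example (illustrative):
///   %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// first materializes %l = vector.constant_mask [3] : vector<8xi1> and then
/// inserts %l at positions 0 and 1 of an all-false vector<4x8xi1>.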
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
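/// For example (illustrative), for %x = vector.create_mask %a, %b : vector<4x8xi1>
/// a single %l = vector.create_mask %b : vector<8xi1> is emitted and, for each
/// leading position i in [0, 4), either %l or an all-false vector<8xi1> is
/// selected based on the comparison i < %a and inserted at position i.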
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
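/// For example (illustrative):
///   %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
/// becomes two extract_strided_slice ops (offset 0 and offset 3, size 3) whose
/// results are inserted as rows 0 and 1 of the 2-D result.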
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This brings broadcast ops
/// closer to contraction ops, so that the CombineContractBroadcast pattern can
/// kick in when cast ops sit between these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
                         castResTy, op->getAttrs());
    auto castOp = rewriter.createOperation(state);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders cast(transpose) to transpose(cast). This brings transpose ops
/// closer to contraction ops, so that the CombineContractTranspose pattern can
/// kick in when cast ops sit between these operations.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
///   %1 = vector.transpose %0, [2, 0, 1]
///     : vector<32x16x8xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnTranspose
    : public OpInterfaceRewritePattern<CastOpInterface> {

  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
    if (!transpOp)
      return failure();

    auto castResTy = transpOp.getVectorType();
    castResTy = VectorType::get(castResTy.getShape(),
                                getElementTypeOrSelf(op->getResult(0)));
    OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
                         castResTy, op->getAttrs());
    auto castOp = rewriter.createOperation(state);
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResult(0).getType(), castOp->getResult(0),
        transpOp.getTransp());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
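  // Note: both operands are flattened to 1-D below, multiplied via
  // vector.matmul, and the flat result is shape_cast back to 2-D before the
  // accumulator is added.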
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Classical row-major matmul with a transposed output: swap the operands
      // and permute what becomes the new lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, so we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
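/// For example (illustrative), a matvec-style contraction of vector<2x3xf32>
/// against vector<3xf32> with a vector<2xf32> accumulator has one free
/// dimension; it is unrolled into two rank-reduced contractions (vector<3xf32>
/// with vector<3xf32> into f32) whose results are inserted back into the
/// vector<2xf32> result.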
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    rewriter.replaceOp(op, lowerReduction(op, rewriter));
    return success();
  }

  return failure();
}

// Lower one parallel dimension.
// TODO: consider reusing existing contract unrolling
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
                                           int64_t lhsIndex, int64_t rhsIndex,
                                           PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator type index and result index.
  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
           "parallel index should be free in LHS or batch in LHS/RHS");
    dimSize = lhsType.getDimSize(lhsIndex);
  } else {
    assert(rhsIndex >= 0 && "missing parallel index");
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
  Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
  assert(lookup.hasValue() && "parallel index not listed in reduction");
  int64_t resIndex = lookup.getValue();
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resType, rewriter.getZeroAttr(resType));
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter);
    Value lowContract = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, lowAffine, lowIter);
    result =
        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
  }
  return result;
}

// Lower one reduction dimension.
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  assert(!resType.isa<VectorType>());
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  assert(lookupLhs.hasValue() && "missing LHS parallel index");
  assert(lookupRhs.hasValue() && "missing RHS parallel index");
  int64_t lhsIndex = lookupLhs.getValue();
  int64_t rhsIndex = lookupRhs.getValue();
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
  // Base case.
  if (lhsType.getRank() == 1) {
    assert(rhsType.getRank() == 1 && "corrupt contraction");
    Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;
    Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
    if (auto acc = op.acc())
      res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
    return res;
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.acc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
                                                    lowAffine, lowIter);
  }
  return result;
}

} // namespace mlir

Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
    OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
    ArrayRef<int64_t> multiplicity, const AffineMap &map) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(op);
  Location loc = op->getLoc();
  if (op->getNumResults() != 1)
    return {};
  Value result = op->getResult(0);
  VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
  if (!type || map.getNumResults() != multiplicity.size())
    return {};
  // For each dimension being distributed, check that the size is a multiple of
  // the multiplicity. To handle more sizes we would need to support masking.
  unsigned multiplictyCount = 0;
  for (auto exp : map.getResults()) {
    auto affinExp = exp.dyn_cast<AffineDimExpr>();
    if (!affinExp || affinExp.getPosition() >= type.getRank() ||
        type.getDimSize(affinExp.getPosition()) %
                multiplicity[multiplictyCount++] !=
            0)
      return {};
  }
  DistributeOps ops;
  ops.extract =
      builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map);
  ops.insert =
      builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids);
  return ops;
}

/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferReadOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank)
      return failure();

    SmallVector<unsigned, 4> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
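    // For example, a minor-identity-with-broadcast map such as
    // (d0, d1) -> (0, d1) is accepted by the check below, whereas a genuine
    // permutation such as (d0, d1) -> (d1, d0) is rejected.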
    if (!read.permutation_map().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return failure();

    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
                                                     vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = VectorType::get(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != read.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return failure();

    // Create vector load op.
    Operation *loadOp;
    if (read.mask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.padding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(),
          read.mask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
                                               unbroadcastedVectorType,
                                               read.source(), read.indices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
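//
// As an illustrative sketch (hypothetical IR, not verbatim from a test), a
// single-element load such as
//   %0 = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
// becomes roughly
//   %s = memref.load %mem[%i] : memref<8xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>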
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.base(), loadOp.indices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.valueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.valueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.base(), storeOp.indices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferWriteOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return failure();

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.permutation_map().isMinorIdentity())
      return failure();

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.mask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.source(), write.indices(), write.mask(), write.vector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.vector(), write.source(), write.indices());
    }
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid an infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.position());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.source(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));

    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %src {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.offsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.sizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    SmallVector<int64_t, 4> dims =
        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
        newSizes, extractOp.strides());

    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};

// Shuffles vector.bitcast op before vector.insert_strided_slice op.
//
// This transforms IR like:
//   %0 = vector.insert_strided_slice %src, %dst {
//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
// Into:
//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
//   %2 = vector.insert_strided_slice %0, %1 {
//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be
    // implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    ArrayAttr newOffsets = insertOp.offsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    SmallVector<int64_t, 4> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastSrcType, insertOp.source());

    SmallVector<int64_t, 4> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
        bitcastOp.getLoc(), newCastDstType, insertOp.dest());

    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.strides());

    return success();
  }
};

static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc,
                                   Type targetType, Value value) {
  if (targetType == value.getType())
    return value;

  bool targetIsIndex = targetType.isIndex();
  bool valueIsIndex = value.getType().isIndex();
  if (targetIsIndex ^ valueIsIndex)
    return rewriter.create<arith::IndexCastOp>(loc, targetType, value);

  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
  assert(targetIntegerType && valueIntegerType &&
         "unexpected cast between types other than integers and index");
  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());

  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
    return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value);
  return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value);
}

// Helper that returns a vector comparison that constructs a mask:
//     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
// If `dim == 0` then the result will be a 0-D vector.
//
// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
//       much more compact, IR for this operation, but LLVM eventually
//       generates more elaborate instructions for this intrinsic since it
//       is very conservative on the boundary conditions.
static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
                                   bool indexOptimizations, int64_t dim,
                                   Value b, Value *off = nullptr) {
  auto loc = op->getLoc();
  // If we can assume all indices fit in 32-bit, we perform the vector
  // comparison in 32-bit to get a higher degree of SIMD parallelism.
  // Otherwise we perform the vector comparison using 64-bit indices.
  Type idxType =
      indexOptimizations ? rewriter.getI32Type() : rewriter.getI64Type();
  DenseIntElementsAttr indicesAttr;
  if (dim == 0 && indexOptimizations) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
  } else if (dim == 0) {
    indicesAttr = DenseIntElementsAttr::get(
        VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
  } else if (indexOptimizations) {
    indicesAttr = rewriter.getI32VectorAttr(
        llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
  } else {
    indicesAttr = rewriter.getI64VectorAttr(
        llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
  }
  Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr);
  // Add in an offset if requested.
  if (off) {
    Value o = createCastToIndexLike(rewriter, loc, idxType, *off);
    Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o);
    indices = rewriter.create<arith::AddIOp>(loc, ov, indices);
  }
  // Construct the vector comparison.
  Value bound = createCastToIndexLike(rewriter, loc, idxType, b);
  Value bounds =
      rewriter.create<vector::SplatOp>(loc, indices.getType(), bound);
  return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices,
                                        bounds);
}

template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt)
      : mlir::OpRewritePattern<ConcreteOp>(context),
        indexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    //       dimensions here.
    unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
    Value off = xferOp.indices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex);
    Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
    Value mask = rewriter.create<vector::CreateMaskOp>(
        loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getNumScalableDims()),
        b);
    if (xferOp.mask()) {
      // Intersect the in-bounds mask with the mask specified as an op
      // parameter.
      mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask());
    }

    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.maskMutable().assign(mask);
      xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  const bool indexOptimizations;
};

/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
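///
/// As an illustrative sketch (not verbatim output), a 1-D mask such as
///   %m = vector.create_mask %b : vector<4xi1>
/// is rewritten along the lines of
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %bnd = vector.splat %b_i32 : vector<4xi32>
///   %m = arith.cmpi slt, %idx, %bnd : vector<4xi32>
/// where %b_i32 is %b cast to the index-like type chosen by
/// `indexOptimizations` (i32 here, i64 otherwise).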
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context),
        indexOptimizations(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, indexOptimizations,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool indexOptimizations;
};

// Drop the innermost contiguous unit dimensions from a transfer_read operand.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();

    // TODO: support mask.
    if (readOp.mask())
      return failure();

    auto srcType = readOp.source().getType().dyn_cast<MemRefType>();
    if (!srcType || !srcType.hasStaticShape())
      return failure();

    if (!readOp.permutation_map().isMinorIdentity())
      return failure();

    auto targetType = readOp.getVectorType();
    if (targetType.getRank() <= 1)
      return failure();

    SmallVector<int64_t> srcStrides;
    int64_t srcOffset;
    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
      return failure();

    size_t dimsToDrop = 0;
    for (size_t i = 1; i < srcStrides.size(); ++i) {
      int dim = srcType.getRank() - i - 1;
      if (srcStrides[dim] == 1) {
        dimsToDrop++;
      } else {
        break;
      }
    }
    if (dimsToDrop == 0)
      return failure();

    auto resultTargetVecType =
        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                        targetType.getElementType());

    MemRefType resultMemrefType;
    if (srcType.getLayout().getAffineMap().isIdentity()) {
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          {}, srcType.getMemorySpaceAsInt());
    } else {
      AffineMap map = srcType.getLayout().getAffineMap();
      int numResultDims = map.getNumDims() - dimsToDrop;
      int numSymbols = map.getNumSymbols();
      for (size_t i = 0; i < dimsToDrop; ++i) {
        int dim = srcType.getRank() - i - 1;
        map = map.replace(rewriter.getAffineDimExpr(dim),
                          rewriter.getAffineConstantExpr(0), numResultDims,
                          numSymbols);
      }
      resultMemrefType = MemRefType::get(
          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
          map, srcType.getMemorySpaceAsInt());
    }

    auto loc = readOp.getLoc();
    SmallVector<int64_t> offsets(srcType.getRank(), 0);
    SmallVector<int64_t> strides(srcType.getRank(), 1);

    ArrayAttr inBoundsAttr =
        readOp.in_bounds()
            ? rewriter.getArrayAttr(
                  readOp.in_boundsAttr().getValue().drop_back(dimsToDrop))
            : ArrayAttr();
    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
        loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(),
        strides);
    auto permMap = getTransferMinorIdentityMap(
        rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
    Value result = rewriter.create<vector::TransferReadOp>(
        loc, resultTargetVecType, rankedReducedView,
        readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
        readOp.padding(),
        // TODO: support mask.
        /*mask=*/Value(), inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
                                                     result);
    return success();
  }
};

namespace {

/// This function checks to see if the vector combining kind
/// is consistent with the integer or float element type.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
  using vector::CombiningKind;
  enum class KindType { FLOAT, INT, INVALID };
  KindType type{KindType::INVALID};
  switch (kind) {
  case CombiningKind::MINF:
  case CombiningKind::MAXF:
    type = KindType::FLOAT;
    break;
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    type = KindType::INT;
    break;
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    type = isInt ? KindType::INT : KindType::FLOAT;
    break;
  }
  bool isValidIntKind = (type == KindType::INT) && isInt;
  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
  return (isValidIntKind || isValidFloatKind);
}

/// This function constructs the appropriate integer or float
/// operation given the vector combining kind and operands. The
/// supported int operations are: add, mul, min (signed/unsigned),
/// max (signed/unsigned), and, or, xor. The supported float
/// operations are: add, mul, min and max.
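///
/// For example, per the switch below, CombiningKind::ADD is materialized as
/// arith.addi for integer/index element types and as arith.addf for
/// floating-point element types.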
static Value genOperator(Location loc, Value x, Value y,
                         vector::CombiningKind kind,
                         PatternRewriter &rewriter) {
  using vector::CombiningKind;

  auto elType = x.getType().cast<VectorType>().getElementType();
  bool isInt = elType.isIntOrIndex();

  Value combinedResult{nullptr};
  switch (kind) {
  case CombiningKind::ADD:
    if (isInt)
      combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
    break;
  case CombiningKind::MUL:
    if (isInt)
      combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
    else
      combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
    break;
  case CombiningKind::MINUI:
    combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
    break;
  case CombiningKind::MINSI:
    combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
    break;
  case CombiningKind::MAXUI:
    combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
    break;
  case CombiningKind::MAXSI:
    combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
    break;
  case CombiningKind::AND:
    combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
    break;
  case CombiningKind::OR:
    combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
    break;
  case CombiningKind::XOR:
    combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
    break;
  case CombiningKind::MINF:
    combinedResult = rewriter.create<arith::MinFOp>(loc, x, y);
    break;
  case CombiningKind::MAXF:
    combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y);
    break;
  }
  return combinedResult;
}

/// Convert vector.scan op into arith ops and
/// vector.insert_strided_slice/extract_strided_slice.
///
/// Ex:
/// ```
///   %0:2 = vector.scan <mul>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
/// Gets converted to:
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.muli %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///     : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.muli %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///     : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.kind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.reduction_dim();
    bool inclusive = scanOp.inclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.source(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors.
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.initial_value());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.initial_value());
          }
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool indexOptimizations) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), indexOptimizations);
}

void mlir::vector::populateShapeCastFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext());
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}

void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BroadcastOpLowering>(patterns.getContext());
}

void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext());
}

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext());
}

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<OuterProductOpLowering>(patterns.getContext());
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(options,
                                                      patterns.getContext());
}

void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext());
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose, ReorderCastOpsOnBroadcast,
               ReorderCastOpsOnTranspose>(patterns.getContext());
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext());
}

void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ScanToArithOps>(patterns.getContext());
}
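// Illustrative usage sketch (kept as a comment; the function name below is
// hypothetical and only the populate* entry points above come from this file).
// A pass that wants these rewrites would typically collect them into a
// RewritePatternSet and apply them greedily, assuming the `mlir` and
// `mlir::vector` namespaces are in scope:
//
//   void runVectorLoweringSketch(Operation *op,
//                                VectorTransformsOptions options) {
//     RewritePatternSet patterns(op->getContext());
//     populateVectorContractLoweringPatterns(patterns, options);
//     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
//     populateVectorMaskMaterializationPatterns(patterns,
//                                               /*indexOptimizations=*/true);
//     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//   }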