//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
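  // For example (illustrative shapes), dropping dimension index 1 at position
  // `pos` of a vector<4x3x2xf32> extracts each of the 4 leading
  // vector<3x2xf32> sub-vectors, recursively reshapes each into a
  // vector<2xf32>, and re-inserts them into a vector<4x2xf32> result.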
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.getSource().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
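/// For example (illustrative shapes), transposing a vector<2x3xf32>
/// shape_casts the source to vector<6xf32>, shuffles it with the mask
/// [0, 3, 1, 4, 2, 5], and shape_casts the result back to vector<3x2xf32>.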
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 && transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering
    : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult =
          isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
                : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
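/// For example (illustrative shapes), %x = vector.constant_mask [2, 3]
/// : vector<4x5xi1> becomes a 1-D mask vector.constant_mask [3] : vector<5xi1>
/// inserted at positions 0 and 1 of an all-false vector<4x5xi1> constant.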
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
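/// For example (illustrative shapes), %x = vector.create_mask %a, %b
/// : vector<4x5xi1> first builds %l = vector.create_mask %b : vector<5xi1>,
/// then, for each row i in [0, 4), selects %l when i < %a (signed) and an
/// all-false vector<5xi1> otherwise, and inserts the selection at row i.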
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
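/// For example (illustrative shapes), vector.shape_cast %s : vector<6xf32> to
/// vector<2x3xf32> extracts the strided slices at offsets 0 and 3 (size 3,
/// stride 1) and inserts them as rows 0 and 1 of the 2-D result.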
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d2)>],
///    iterator_types = ["parallel", "reduction", "parallel"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in the CombineContractBroadcast pattern
/// when casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in the
/// CombineContractTranspose pattern when elementwise ops are between these
/// operations. Ex:
/// ```
///   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayAttr, 4> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getTransp());
        srcType = transposeOp.getVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::is_splat(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value, 4> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = extractVector<unsigned>(transposeMaps.front());
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType = VectorType::get(
            srcType.getShape(),
            operand.getType().cast<VectorType>().getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand,
            rewriter.getI64ArrayAttr(invOrder)));
      }
    }

    auto vectorType = VectorType::get(
        srcType.getShape(),
        op->getResultTypes()[0].cast<VectorType>().getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise creates an
/// arith::AddFOp, using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise creates an
/// arith::MulFOp, using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
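/// For example (illustrative shapes), a row-major contraction of
/// vector<2x4xf32> with vector<4x3xf32> into vector<2x3xf32> shape_casts the
/// operands to vector<8xf32> and vector<12xf32>, multiplies them with a
/// 2x4 * 4x3 vector.matmul, and shape_casts the vector<6xf32> product back to
/// vector<2x3xf32> before adding the accumulator %c.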
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMaps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMaps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMaps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(
                rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {
  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
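/// For example (illustrative shapes), a row-major matmul of vector<2x4xf32>
/// with vector<4x3xf32> transposes the lhs to vector<4x2xf32> and emits 4
/// vector.outerproduct ops over vector<2xf32> and vector<3xf32> rows,
/// accumulating into the vector<2x3xf32> result.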
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
1483 Value tmp = lhs; 1484 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1485 rhs = tmp; 1486 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 1487 std::swap(lhs, rhs); 1488 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 1489 Value tmp = lhs; 1490 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1491 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm); 1492 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 1493 Value tmp = rhs; 1494 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1495 lhs = tmp; 1496 } else { 1497 return failure(); 1498 } 1499 } else if (isParallelIterator(iteratorTypes[0]) && 1500 isReductionIterator(iteratorTypes[1])) { 1501 // 1502 // One outer parallel, one inner reduction (matvec flavor) 1503 // 1504 if (maps == infer({{m, n}, {n}, {m}})) { 1505 // No need to permute anything. 1506 } else if (maps == infer({{n, m}, {n}, {m}})) { 1507 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1508 } else if (maps == infer({{n}, {m, n}, {m}})) { 1509 std::swap(lhs, rhs); 1510 } else if (maps == infer({{n}, {n, m}, {m}})) { 1511 std::swap(lhs, rhs); 1512 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1513 } else { 1514 return failure(); 1515 } 1516 } else { 1517 return failure(); 1518 } 1519 1520 VectorType dstType = op.getResultType().cast<VectorType>(); 1521 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && 1522 "Expected dst type of rank 1 or 2"); 1523 1524 unsigned rank = dstType.getRank(); 1525 unsigned dstRows = dstType.getShape()[0]; 1526 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; 1527 1528 // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 1529 Value res = rewriter.create<arith::ConstantOp>(loc, dstType, 1530 rewriter.getZeroAttr(dstType)); 1531 bool isInt = dstType.getElementType().isa<IntegerType>(); 1532 for (unsigned r = 0; r < dstRows; ++r) { 1533 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r); 1534 for (unsigned c = 0; c < dstColumns; ++c) { 1535 Value b = rank == 1 1536 ? rhs 1537 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c); 1538 Value m = createMul(op.getLoc(), a, b, isInt, rewriter); 1539 Value reduced = rewriter.create<vector::ReductionOp>( 1540 op.getLoc(), vector::CombiningKind::ADD, m); 1541 1542 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r} 1543 : SmallVector<int64_t, 2>{r, c}; 1544 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos); 1545 } 1546 } 1547 if (auto acc = op.getAcc()) 1548 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1549 rewriter.replaceOp(op, res); 1550 return success(); 1551 } 1552 1553 /// Progressive lowering of ContractionOp. 1554 /// One: 1555 /// %x = vector.contract with at least one free/batch dimension 1556 /// is replaced by: 1557 /// %a = vector.contract with one less free/batch dimension 1558 /// %b = vector.contract with one less free/batch dimension 1559 /// .. 1560 /// %x = combine %a %b .. 1561 /// until a pure contraction is reached (no free/batch dimensions), 1562 /// which is replaced by a dot-product. 1563 /// 1564 /// This only kicks in when either VectorTransformsOptions is set 1565 /// to DOT or when other contraction patterns fail. 
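/// For example (an illustrative sketch, not from the original source), a
/// batched matmul contraction over vector<8x2x4xf32> and vector<8x4x3xf32>
/// first unrolls its batch dimension into 8 rank-reduced vector.contract ops
/// over vector<2x4xf32> and vector<4x3xf32>; each of those then unrolls its
/// remaining free dimensions the same way until only 1-D reductions remain,
/// which are lowered to dot products.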
1566 // 1567 // TODO: break down into transpose/reshape/cast ops 1568 // when they become available to avoid code dup 1569 // TODO: investigate lowering order impact on performance 1570 LogicalResult 1571 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1572 PatternRewriter &rewriter) const { 1573 // TODO: implement masks. 1574 if (llvm::size(op.getMasks()) != 0) 1575 return failure(); 1576 1577 if (failed(filter(op))) 1578 return failure(); 1579 1580 // TODO: support mixed mode contract lowering. 1581 if (op.getLhsType().getElementType() != 1582 getElementTypeOrSelf(op.getAccType()) || 1583 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1584 return failure(); 1585 1586 // TODO: implement benefits, cost models. 1587 MLIRContext *ctx = op.getContext(); 1588 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1589 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1590 return success(); 1591 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1592 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1593 return success(); 1594 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1595 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1596 return success(); 1597 1598 // Find first batch dimension in LHS/RHS, and lower when found. 1599 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1600 if (!batchDimMap.empty()) { 1601 int64_t lhsIndex = batchDimMap[0].first; 1602 int64_t rhsIndex = batchDimMap[0].second; 1603 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1604 return success(); 1605 } 1606 1607 // Collect contracting dimensions. 1608 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1609 op.getContractingDimMap(); 1610 DenseSet<int64_t> lhsContractingDimSet; 1611 DenseSet<int64_t> rhsContractingDimSet; 1612 for (auto &dimPair : contractingDimMap) { 1613 lhsContractingDimSet.insert(dimPair.first); 1614 rhsContractingDimSet.insert(dimPair.second); 1615 } 1616 1617 // Find first free dimension in LHS, and lower when found. 1618 VectorType lhsType = op.getLhsType(); 1619 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1620 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1621 rewriter.replaceOp( 1622 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1623 return success(); 1624 } 1625 } 1626 1627 // Find first free dimension in RHS, and lower when found. 1628 VectorType rhsType = op.getRhsType(); 1629 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1630 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1631 rewriter.replaceOp( 1632 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1633 return success(); 1634 } 1635 } 1636 1637 // Lower the first remaining reduction dimension. 1638 if (!contractingDimMap.empty()) { 1639 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1640 return success(); 1641 } 1642 1643 return failure(); 1644 } 1645 1646 // Lower one parallel dimension. 1647 // TODO: consider reusing existing contract unrolling 1648 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1649 int64_t lhsIndex, int64_t rhsIndex, 1650 PatternRewriter &rewriter) const { 1651 VectorType lhsType = op.getLhsType(); 1652 VectorType rhsType = op.getRhsType(); 1653 VectorType resType = op.getResultType().cast<VectorType>(); 1654 // Find the iterator type index and result index. 
1655 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1656 int64_t iterIndex = -1; 1657 int64_t dimSize = -1; 1658 if (lhsIndex >= 0) { 1659 iterIndex = iMap[0].getDimPosition(lhsIndex); 1660 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1661 "parallel index should be free in LHS or batch in LHS/RHS"); 1662 dimSize = lhsType.getDimSize(lhsIndex); 1663 } else { 1664 assert(rhsIndex >= 0 && "missing parallel index"); 1665 iterIndex = iMap[1].getDimPosition(rhsIndex); 1666 dimSize = rhsType.getDimSize(rhsIndex); 1667 } 1668 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1669 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1670 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1671 int64_t resIndex = lookup.getValue(); 1672 // Construct new iterator types and affine map array attribute. 1673 std::array<AffineMap, 3> lowIndexingMaps = { 1674 adjustMap(iMap[0], iterIndex, rewriter), 1675 adjustMap(iMap[1], iterIndex, rewriter), 1676 adjustMap(iMap[2], iterIndex, rewriter)}; 1677 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1678 auto lowIter = 1679 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1680 // Unroll into a series of lower dimensional vector.contract ops. 1681 Location loc = op.getLoc(); 1682 Value result = rewriter.create<arith::ConstantOp>( 1683 loc, resType, rewriter.getZeroAttr(resType)); 1684 for (int64_t d = 0; d < dimSize; ++d) { 1685 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1686 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1687 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1688 Value lowContract = rewriter.create<vector::ContractionOp>( 1689 loc, lhs, rhs, acc, lowAffine, lowIter); 1690 result = 1691 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1692 } 1693 return result; 1694 } 1695 1696 // Lower one reduction dimension. 1697 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1698 PatternRewriter &rewriter) const { 1699 auto loc = op.getLoc(); 1700 VectorType lhsType = op.getLhsType(); 1701 VectorType rhsType = op.getRhsType(); 1702 Type resType = op.getResultType(); 1703 assert(!resType.isa<VectorType>()); 1704 bool isInt = resType.isa<IntegerType>(); 1705 // Use iterator index 0. 1706 int64_t iterIndex = 0; 1707 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1708 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1709 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1710 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1711 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1712 int64_t lhsIndex = lookupLhs.getValue(); 1713 int64_t rhsIndex = lookupRhs.getValue(); 1714 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1715 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1716 // Base case. 1717 if (lhsType.getRank() == 1) { 1718 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1719 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 1720 auto kind = vector::CombiningKind::ADD; 1721 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1722 if (auto acc = op.getAcc()) 1723 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1724 return res; 1725 } 1726 // Construct new iterator types and affine map array attribute. 
1727 std::array<AffineMap, 3> lowIndexingMaps = { 1728 adjustMap(iMap[0], iterIndex, rewriter), 1729 adjustMap(iMap[1], iterIndex, rewriter), 1730 adjustMap(iMap[2], iterIndex, rewriter)}; 1731 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1732 auto lowIter = 1733 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1734 // Unroll into a series of lower dimensional vector.contract ops. 1735 // By feeding the initial accumulator into the first contraction, 1736 // and the result of each contraction into the next, eventually 1737 // the sum of all reductions is computed. 1738 Value result = op.getAcc(); 1739 for (int64_t d = 0; d < dimSize; ++d) { 1740 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1741 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1742 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1743 lowAffine, lowIter); 1744 } 1745 return result; 1746 } 1747 1748 } // namespace mlir 1749 1750 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1751 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1752 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1753 OpBuilder::InsertionGuard guard(builder); 1754 builder.setInsertionPointAfter(op); 1755 Location loc = op->getLoc(); 1756 if (op->getNumResults() != 1) 1757 return {}; 1758 Value result = op->getResult(0); 1759 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1760 if (!type || map.getNumResults() != multiplicity.size()) 1761 return {}; 1762 // For each dimension being distributed check that the size is a multiple of 1763 // the multiplicity. To handle more sizes we would need to support masking. 1764 unsigned multiplictyCount = 0; 1765 for (auto exp : map.getResults()) { 1766 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1767 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1768 type.getDimSize(affinExp.getPosition()) % 1769 multiplicity[multiplictyCount++] != 1770 0) 1771 return {}; 1772 } 1773 DistributeOps ops; 1774 ops.extract = 1775 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1776 ops.insert = 1777 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1778 return ops; 1779 } 1780 1781 /// Progressive lowering of transfer_read. This pattern supports lowering of 1782 /// `vector.transfer_read` to a combination of `vector.load` and 1783 /// `vector.broadcast` if all of the following hold: 1784 /// - Stride of most minor memref dimension must be 1. 1785 /// - Out-of-bounds masking is not required. 1786 /// - If the memref's element type is a vector type then it coincides with the 1787 /// result type. 1788 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1789 struct TransferReadToVectorLoadLowering 1790 : public OpRewritePattern<vector::TransferReadOp> { 1791 TransferReadToVectorLoadLowering(MLIRContext *context, 1792 llvm::Optional<unsigned> maxRank) 1793 : OpRewritePattern<vector::TransferReadOp>(context), 1794 maxTransferRank(maxRank) {} 1795 1796 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1797 PatternRewriter &rewriter) const override { 1798 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1799 return failure(); 1800 1801 SmallVector<unsigned, 4> broadcastedDims; 1802 // Permutations are handled by VectorToSCF or 1803 // populateVectorTransferPermutationMapLoweringPatterns. 
1804 // We let the 0-d corner case pass-through as it is supported. 1805 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( 1806 &broadcastedDims)) 1807 return failure(); 1808 1809 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1810 if (!memRefType) 1811 return failure(); 1812 1813 // Non-unit strides are handled by VectorToSCF. 1814 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1815 return failure(); 1816 1817 // If there is broadcasting involved then we first load the unbroadcasted 1818 // vector, and then broadcast it with `vector.broadcast`. 1819 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1820 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1821 vectorShape.end()); 1822 for (unsigned i : broadcastedDims) 1823 unbroadcastedVectorShape[i] = 1; 1824 VectorType unbroadcastedVectorType = VectorType::get( 1825 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1826 1827 // `vector.load` supports vector types as memref's elements only when the 1828 // resulting vector type is the same as the element type. 1829 auto memrefElTy = memRefType.getElementType(); 1830 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1831 return failure(); 1832 1833 // Otherwise, element types of the memref and the vector must match. 1834 if (!memrefElTy.isa<VectorType>() && 1835 memrefElTy != read.getVectorType().getElementType()) 1836 return failure(); 1837 1838 // Out-of-bounds dims are handled by MaterializeTransferMask. 1839 if (read.hasOutOfBoundsDim()) 1840 return failure(); 1841 1842 // Create vector load op. 1843 Operation *loadOp; 1844 if (read.getMask()) { 1845 Value fill = rewriter.create<vector::SplatOp>( 1846 read.getLoc(), unbroadcastedVectorType, read.getPadding()); 1847 loadOp = rewriter.create<vector::MaskedLoadOp>( 1848 read.getLoc(), unbroadcastedVectorType, read.getSource(), 1849 read.getIndices(), read.getMask(), fill); 1850 } else { 1851 loadOp = rewriter.create<vector::LoadOp>( 1852 read.getLoc(), unbroadcastedVectorType, read.getSource(), 1853 read.getIndices()); 1854 } 1855 1856 // Insert a broadcasting op if required. 1857 if (!broadcastedDims.empty()) { 1858 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1859 read, read.getVectorType(), loadOp->getResult(0)); 1860 } else { 1861 rewriter.replaceOp(read, loadOp->getResult(0)); 1862 } 1863 1864 return success(); 1865 } 1866 1867 llvm::Optional<unsigned> maxTransferRank; 1868 }; 1869 1870 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1871 // TODO: we shouldn't cross the vector/scalar domains just for this 1872 // but atm we lack the infra to avoid it. Possible solutions include: 1873 // - go directly to LLVM + bitcast 1874 // - introduce a bitcast op and likely a new pointer dialect 1875 // - let memref.load/store additionally support the 0-d vector case 1876 // There are still deeper data layout issues lingering even in this 1877 // trivial case (for architectures for which this matters). 
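//
// Illustrative example (a sketch added for exposition, not from the original
// code): a single-element load such as
//   %0 = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
// is rewritten to
//   %s = memref.load %mem[%i] : memref<8xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>
// where %mem and %i are hypothetical names used only for this example.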
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return failure();
    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     llvm::Optional<unsigned> maxRank)
      : OpRewritePattern<vector::TransferWriteOp>(context),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank)
      return failure();

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return failure();

    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
    if (!memRefType)
      return failure();

    // Non-unit strides are handled by VectorToSCF.
    if (!vector::isLastMemrefDimUnitStride(memRefType))
      return failure();

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
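    // E.g. (illustrative, not from the original code): writing into a
    // memref<16xvector<4xf32>> is only supported when the stored value is
    // exactly a vector<4xf32>.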
    auto memrefElTy = memRefType.getElementType();
    if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
      return failure();

    // Otherwise, element types of the memref and the vector must match.
    if (!memrefElTy.isa<VectorType>() &&
        memrefElTy != write.getVectorType().getElementType())
      return failure();

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid an infinite loop, given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.getPosition());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.getSource(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar.
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getType(), castedValue,
        rewriter.getI64ArrayAttr(index % expandRatio));

    return success();
  }
};

// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
//   %0 = vector.extract_strided_slice %cast {
//          offsets = [4], sizes = [4], strides = [1]
//        } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %arg0 {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all-one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOneValue(); }))
      return failure();

    unsigned rank = extractOp.getVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have fewer offsets than the rank, then implicitly we are selecting
    // the full range for the last bitcasted dimension; other dimensions aren't
    // affected. Otherwise, we need to scale down the last dimension's offset
    // given we are extracting from fewer elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
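    // E.g. (illustrative, not from the original code): with expandRatio == 2
    // (f32 bitcast to f16), a slice of 4 f16 elements corresponds to a slice
    // of 2 f32 elements in the cast source.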
2117 ArrayAttr newSizes = extractOp.getSizes(); 2118 if (newSizes.size() == rank) { 2119 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2120 if (sizes.back() % expandRatio != 0) 2121 return failure(); 2122 sizes.back() = sizes.back() / expandRatio; 2123 newSizes = rewriter.getI64ArrayAttr(sizes); 2124 } 2125 2126 SmallVector<int64_t, 4> dims = 2127 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2128 dims.back() = dims.back() / expandRatio; 2129 VectorType newExtractType = 2130 VectorType::get(dims, castSrcType.getElementType()); 2131 2132 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2133 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, 2134 newSizes, extractOp.getStrides()); 2135 2136 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2137 extractOp, extractOp.getType(), newExtractOp); 2138 2139 return success(); 2140 } 2141 }; 2142 2143 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2144 // 2145 // This transforms IR like: 2146 // %0 = vector.insert_strided_slice %src, %dst { 2147 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2148 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2149 // Into: 2150 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2151 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2152 // %2 = vector.insert_strided_slice %src, %dst { 2153 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2154 struct BubbleUpBitCastForStridedSliceInsert 2155 : public OpRewritePattern<vector::BitCastOp> { 2156 using OpRewritePattern::OpRewritePattern; 2157 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2158 PatternRewriter &rewriter) const override { 2159 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2160 VectorType castDstType = bitcastOp.getResultVectorType(); 2161 assert(castSrcType.getRank() == castDstType.getRank()); 2162 2163 int64_t castSrcLastDim = castSrcType.getShape().back(); 2164 int64_t castDstLastDim = castDstType.getShape().back(); 2165 // Require casting to less elements for now; other cases to be implemented. 2166 if (castSrcLastDim < castDstLastDim) 2167 return failure(); 2168 2169 assert(castSrcLastDim % castDstLastDim == 0); 2170 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2171 2172 auto insertOp = 2173 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>(); 2174 if (!insertOp) 2175 return failure(); 2176 2177 // Only accept all one strides for now. 2178 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(), 2179 [](const APInt &val) { return !val.isOneValue(); })) 2180 return failure(); 2181 2182 unsigned rank = insertOp.getSourceVectorType().getRank(); 2183 // Require insert op to have the same rank for the source and destination 2184 // vector; other cases to be implemented. 
2185 if (rank != insertOp.getDestVectorType().getRank()) 2186 return failure(); 2187 2188 ArrayAttr newOffsets = insertOp.getOffsets(); 2189 assert(newOffsets.size() == rank); 2190 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2191 if (offsets.back() % shrinkRatio != 0) 2192 return failure(); 2193 offsets.back() = offsets.back() / shrinkRatio; 2194 newOffsets = rewriter.getI64ArrayAttr(offsets); 2195 2196 SmallVector<int64_t, 4> srcDims = 2197 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2198 srcDims.back() = srcDims.back() / shrinkRatio; 2199 VectorType newCastSrcType = 2200 VectorType::get(srcDims, castDstType.getElementType()); 2201 2202 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2203 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 2204 2205 SmallVector<int64_t, 4> dstDims = 2206 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2207 dstDims.back() = dstDims.back() / shrinkRatio; 2208 VectorType newCastDstType = 2209 VectorType::get(dstDims, castDstType.getElementType()); 2210 2211 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2212 bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 2213 2214 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2215 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2216 insertOp.getStrides()); 2217 2218 return success(); 2219 } 2220 }; 2221 2222 // Helper that returns a vector comparison that constructs a mask: 2223 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2224 // 2225 // If `dim == 0` then the result will be a 0-D vector. 2226 // 2227 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2228 // much more compact, IR for this operation, but LLVM eventually 2229 // generates more elaborate instructions for this intrinsic since it 2230 // is very conservative on the boundary conditions. 2231 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2232 bool force32BitVectorIndices, int64_t dim, 2233 Value b, Value *off = nullptr) { 2234 auto loc = op->getLoc(); 2235 // If we can assume all indices fit in 32-bit, we perform the vector 2236 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2237 // Otherwise we perform the vector comparison using 64-bit indices. 2238 Type idxType = 2239 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 2240 DenseIntElementsAttr indicesAttr; 2241 if (dim == 0 && force32BitVectorIndices) { 2242 indicesAttr = DenseIntElementsAttr::get( 2243 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2244 } else if (dim == 0) { 2245 indicesAttr = DenseIntElementsAttr::get( 2246 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2247 } else if (force32BitVectorIndices) { 2248 indicesAttr = rewriter.getI32VectorAttr( 2249 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2250 } else { 2251 indicesAttr = rewriter.getI64VectorAttr( 2252 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2253 } 2254 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2255 // Add in an offset if requested. 2256 if (off) { 2257 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2258 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2259 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2260 } 2261 // Construct the vector comparison. 
2262 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2263 Value bounds = 2264 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2265 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2266 bounds); 2267 } 2268 2269 template <typename ConcreteOp> 2270 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2271 public: 2272 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2273 : mlir::OpRewritePattern<ConcreteOp>(context), 2274 force32BitVectorIndices(enableIndexOpt) {} 2275 2276 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2277 PatternRewriter &rewriter) const override { 2278 if (!xferOp.hasOutOfBoundsDim()) 2279 return failure(); 2280 2281 if (xferOp.getVectorType().getRank() > 1 || 2282 llvm::size(xferOp.getIndices()) == 0) 2283 return failure(); 2284 2285 Location loc = xferOp->getLoc(); 2286 VectorType vtp = xferOp.getVectorType(); 2287 2288 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2289 // set and [dim - offset .. vector_length) unset. 2290 // 2291 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2292 // dimensions here. 2293 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 2294 Value off = xferOp.getIndices()[lastIndex]; 2295 Value dim = 2296 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 2297 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2298 Value mask = rewriter.create<vector::CreateMaskOp>( 2299 loc, 2300 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2301 vtp.getNumScalableDims()), 2302 b); 2303 if (xferOp.getMask()) { 2304 // Intersect the in-bounds with the mask specified as an op parameter. 2305 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 2306 } 2307 2308 rewriter.updateRootInPlace(xferOp, [&]() { 2309 xferOp.getMaskMutable().assign(mask); 2310 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 2311 }); 2312 2313 return success(); 2314 } 2315 2316 private: 2317 const bool force32BitVectorIndices; 2318 }; 2319 2320 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2321 class VectorCreateMaskOpConversion 2322 : public OpRewritePattern<vector::CreateMaskOp> { 2323 public: 2324 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2325 bool enableIndexOpt) 2326 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2327 force32BitVectorIndices(enableIndexOpt) {} 2328 2329 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2330 PatternRewriter &rewriter) const override { 2331 auto dstType = op.getType(); 2332 if (dstType.cast<VectorType>().isScalable()) 2333 return failure(); 2334 int64_t rank = dstType.getRank(); 2335 if (rank > 1) 2336 return failure(); 2337 rewriter.replaceOp( 2338 op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 2339 rank == 0 ? 0 : dstType.getDimSize(0), 2340 op.getOperand(0))); 2341 return success(); 2342 } 2343 2344 private: 2345 const bool force32BitVectorIndices; 2346 }; 2347 2348 // Drop inner most contiguous unit dimensions from transfer_read operand. 2349 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2350 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2351 2352 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2353 PatternRewriter &rewriter) const override { 2354 // TODO: support 0-d corner case. 
2355 if (readOp.getTransferRank() == 0) 2356 return failure(); 2357 2358 // TODO: support mask. 2359 if (readOp.getMask()) 2360 return failure(); 2361 2362 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>(); 2363 if (!srcType || !srcType.hasStaticShape()) 2364 return failure(); 2365 2366 if (!readOp.getPermutationMap().isMinorIdentity()) 2367 return failure(); 2368 2369 auto targetType = readOp.getVectorType(); 2370 if (targetType.getRank() <= 1) 2371 return failure(); 2372 2373 SmallVector<int64_t> srcStrides; 2374 int64_t srcOffset; 2375 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2376 return failure(); 2377 2378 size_t dimsToDrop = 0; 2379 for (size_t i = 1; i < srcStrides.size(); ++i) { 2380 int dim = srcType.getRank() - i - 1; 2381 if (srcStrides[dim] == 1) { 2382 dimsToDrop++; 2383 } else { 2384 break; 2385 } 2386 } 2387 if (dimsToDrop == 0) 2388 return failure(); 2389 2390 auto resultTargetVecType = 2391 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2392 targetType.getElementType()); 2393 2394 MemRefType resultMemrefType; 2395 if (srcType.getLayout().getAffineMap().isIdentity()) { 2396 resultMemrefType = MemRefType::get( 2397 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2398 {}, srcType.getMemorySpaceAsInt()); 2399 } else { 2400 AffineMap map = srcType.getLayout().getAffineMap(); 2401 int numSymbols = map.getNumSymbols(); 2402 for (size_t i = 0; i < dimsToDrop; ++i) { 2403 int dim = srcType.getRank() - i - 1; 2404 map = map.replace(rewriter.getAffineDimExpr(dim), 2405 rewriter.getAffineConstantExpr(0), 2406 map.getNumDims() - 1, numSymbols); 2407 } 2408 resultMemrefType = MemRefType::get( 2409 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2410 map, srcType.getMemorySpaceAsInt()); 2411 } 2412 2413 auto loc = readOp.getLoc(); 2414 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2415 SmallVector<int64_t> strides(srcType.getRank(), 1); 2416 2417 ArrayAttr inBoundsAttr = 2418 readOp.getInBounds() 2419 ? rewriter.getArrayAttr( 2420 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) 2421 : ArrayAttr(); 2422 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2423 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), 2424 strides); 2425 auto permMap = getTransferMinorIdentityMap( 2426 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2427 Value result = rewriter.create<vector::TransferReadOp>( 2428 loc, resultTargetVecType, rankedReducedView, 2429 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2430 readOp.getPadding(), 2431 // TODO: support mask. 2432 /*mask=*/Value(), inBoundsAttr); 2433 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2434 result); 2435 return success(); 2436 } 2437 }; 2438 2439 namespace { 2440 2441 /// This function checks to see if the vector combining kind 2442 /// is consistent with the integer or float element type. 
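/// For example (illustrative): MINF and MAXF are only valid for float element
/// types; MINUI/MINSI/MAXUI/MAXSI/AND/OR/XOR are only valid for integer
/// element types; ADD and MUL are valid for both.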
2443 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2444 using vector::CombiningKind; 2445 enum class KindType { FLOAT, INT, INVALID }; 2446 KindType type{KindType::INVALID}; 2447 switch (kind) { 2448 case CombiningKind::MINF: 2449 case CombiningKind::MAXF: 2450 type = KindType::FLOAT; 2451 break; 2452 case CombiningKind::MINUI: 2453 case CombiningKind::MINSI: 2454 case CombiningKind::MAXUI: 2455 case CombiningKind::MAXSI: 2456 case CombiningKind::AND: 2457 case CombiningKind::OR: 2458 case CombiningKind::XOR: 2459 type = KindType::INT; 2460 break; 2461 case CombiningKind::ADD: 2462 case CombiningKind::MUL: 2463 type = isInt ? KindType::INT : KindType::FLOAT; 2464 break; 2465 } 2466 bool isValidIntKind = (type == KindType::INT) && isInt; 2467 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2468 return (isValidIntKind || isValidFloatKind); 2469 } 2470 2471 /// This function constructs the appropriate integer or float 2472 /// operation given the vector combining kind and operands. The 2473 /// supported int operations are : add, mul, min (signed/unsigned), 2474 /// max(signed/unsigned), and, or, xor. The supported float 2475 /// operations are : add, mul, min and max. 2476 static Value genOperator(Location loc, Value x, Value y, 2477 vector::CombiningKind kind, 2478 PatternRewriter &rewriter) { 2479 using vector::CombiningKind; 2480 2481 auto elType = x.getType().cast<VectorType>().getElementType(); 2482 bool isInt = elType.isIntOrIndex(); 2483 2484 Value combinedResult{nullptr}; 2485 switch (kind) { 2486 case CombiningKind::ADD: 2487 if (isInt) 2488 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2489 else 2490 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2491 break; 2492 case CombiningKind::MUL: 2493 if (isInt) 2494 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2495 else 2496 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2497 break; 2498 case CombiningKind::MINUI: 2499 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2500 break; 2501 case CombiningKind::MINSI: 2502 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2503 break; 2504 case CombiningKind::MAXUI: 2505 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2506 break; 2507 case CombiningKind::MAXSI: 2508 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2509 break; 2510 case CombiningKind::AND: 2511 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2512 break; 2513 case CombiningKind::OR: 2514 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2515 break; 2516 case CombiningKind::XOR: 2517 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2518 break; 2519 case CombiningKind::MINF: 2520 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2521 break; 2522 case CombiningKind::MAXF: 2523 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2524 break; 2525 } 2526 return combinedResult; 2527 } 2528 2529 /// Convert vector.scan op into arith ops and 2530 /// vector.insert_strided_slice/extract_strided_slice 2531 /// 2532 /// Ex: 2533 /// ``` 2534 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2535 /// 1} : 2536 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2537 /// ``` 2538 /// Gets converted to: 2539 /// ``` 2540 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2541 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2542 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32>
/// %1 = vector.insert_strided_slice %0, %cst {offsets = [0, 0],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %2 = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [2, 1],
///      strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
/// %3 = arith.addi %0, %2 : vector<2x1xi32>
/// %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %5 = vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1],
///      strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32>
/// %6 = arith.addi %3, %5 : vector<2x1xi32>
/// %7 = vector.insert_strided_slice %6, %4 {offsets = [0, 2],
///      strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32>
/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors.
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        Value y = inclusive ?
input : lastInput; 2611 output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); 2612 assert(output != nullptr); 2613 } 2614 result = rewriter.create<vector::InsertStridedSliceOp>( 2615 loc, output, result, offsets, strides); 2616 lastOutput = output; 2617 lastInput = input; 2618 } 2619 2620 Value reduction; 2621 if (initialValueRank == 0) { 2622 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2623 reduction = 2624 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2625 } else { 2626 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2627 lastOutput); 2628 } 2629 2630 rewriter.replaceOp(scanOp, {result, reduction}); 2631 return success(); 2632 } 2633 }; 2634 2635 } // namespace 2636 2637 void mlir::vector::populateVectorMaskMaterializationPatterns( 2638 RewritePatternSet &patterns, bool force32BitVectorIndices) { 2639 patterns.add<VectorCreateMaskOpConversion, 2640 MaterializeTransferMask<vector::TransferReadOp>, 2641 MaterializeTransferMask<vector::TransferWriteOp>>( 2642 patterns.getContext(), force32BitVectorIndices); 2643 } 2644 2645 void mlir::vector::populateShapeCastFoldingPatterns( 2646 RewritePatternSet &patterns) { 2647 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2648 } 2649 2650 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2651 RewritePatternSet &patterns) { 2652 patterns.add<BubbleDownVectorBitCastForExtract, 2653 BubbleDownBitCastForStridedSliceExtract, 2654 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2655 } 2656 2657 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2658 RewritePatternSet &patterns) { 2659 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2660 } 2661 2662 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2663 RewritePatternSet &patterns) { 2664 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2665 patterns.getContext()); 2666 } 2667 2668 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2669 RewritePatternSet &patterns) { 2670 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2671 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2672 patterns.getContext()); 2673 } 2674 2675 void mlir::vector::populateVectorContractLoweringPatterns( 2676 RewritePatternSet &patterns, VectorTransformsOptions options) { 2677 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2678 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2679 ContractionOpToOuterProductOpLowering>(options, 2680 patterns.getContext()); 2681 } 2682 2683 void mlir::vector::populateVectorTransposeLoweringPatterns( 2684 RewritePatternSet &patterns, VectorTransformsOptions options) { 2685 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2686 options, patterns.getContext()); 2687 } 2688 2689 void mlir::vector::populateVectorReductionToContractPatterns( 2690 RewritePatternSet &patterns) { 2691 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2692 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2693 ReorderElementwiseOpsOnTranspose>(patterns.getContext()); 2694 } 2695 2696 void mlir::vector:: 2697 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2698 RewritePatternSet &patterns) { 2699 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2700 } 2701 2702 void mlir::vector::populateVectorTransferLoweringPatterns( 2703 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2704 patterns.add<TransferReadToVectorLoadLowering, 2705 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2706 maxTransferRank); 2707 patterns 2708 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2709 patterns.getContext()); 2710 } 2711 2712 void mlir::vector::populateVectorScanLoweringPatterns( 2713 RewritePatternSet &patterns) { 2714 patterns.add<ScanToArithOps>(patterns.getContext()); 2715 } 2716
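// Usage sketch (illustrative, not part of the original file): a pass that
// wants these rewrites would typically collect the relevant populate*
// entry points into one RewritePatternSet and apply them greedily, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   vector::VectorTransformsOptions options;
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorTransferLoweringPatterns(patterns, llvm::None);
//   vector::populateVectorScanLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//
// where `ctx` is the MLIRContext and `op` is the enclosing operation, both
// assumed to be provided by the surrounding pass.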