//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
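  // The extraction happens at a deeper dimension: peel off the leading
  // dimension, recurse on each sub-vector with the extraction index shifted
  // by one, and reassemble the results into a vector whose shape has
  // dimension `index` dropped.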
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.result().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.source().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.source().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.source());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.source());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
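    // The `Flat` option lowers a 2-D transpose by flattening the vector to
    // 1-D, emitting vector.flat_transpose (which later maps to the LLVM
    // matrix-transpose intrinsic), and casting the result back to 2-D.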
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate fully unrolled extract/insert ops.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    SmallVector<int64_t, 4> lhs(transp.size(), 0);
    SmallVector<int64_t, 4> rhs(transp.size(), 0);
    rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs,
                                         op.vector(), result, rewriter));
    return success();
  }

private:
  // Builds the indices arrays for the lhs and rhs. Generates the
  // extract/insert operation when all ranks are exhausted.
  Value expandIndices(Location loc, VectorType resType, int64_t pos,
                      SmallVector<int64_t, 4> &transp,
                      SmallVector<int64_t, 4> &lhs,
                      SmallVector<int64_t, 4> &rhs, Value input, Value result,
                      PatternRewriter &rewriter) const {
    if (pos >= resType.getRank()) {
      auto ridx = rewriter.getI64ArrayAttr(rhs);
      auto lidx = rewriter.getI64ArrayAttr(lhs);
      Type eltType = resType.getElementType();
      Value e = rewriter.create<vector::ExtractOp>(loc, eltType, input, ridx);
      return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
    }
    for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
      lhs[pos] = d;
      rhs[transp[pos]] = d;
      result = expandIndices(loc, resType, pos + 1, transp, lhs, rhs, input,
                             result, rewriter);
    }
    return result;
  }

  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.transp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()), op.vector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
    vector::CombiningKind kind = op.kind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
      Optional<Value> mult =
          isInt ?
              genMultI(loc, op.lhs(), b, acc, kind, rewriter)
                : genMultF(loc, op.lhs(), b, acc, kind, rewriter);
      if (!mult.hasValue())
        return failure();
      rewriter.replaceOp(op, mult.getValue());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x = rewriter.create<vector::ExtractOp>(loc, eltType, op.lhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
                                : genMultF(loc, a, op.rhs(), r, kind, rewriter);
      if (!m.hasValue())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }

  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
                                  vector::CombiningKind kind,
                                  PatternRewriter &rewriter) {
    using vector::CombiningKind;

    // Special case for fused multiply-add.
    if (acc && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }

    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);

    if (!acc)
      return Optional<Value>(mul);

    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
        kind == CombiningKind::OR || kind == CombiningKind::XOR)
      // Already handled or only valid for integer types.
      return Optional<Value>();

    return makeArithReduction(rewriter, loc, kind, mul, acc);
  }
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
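///
/// For instance (an illustrative case, not taken from a test):
///   %x = vector.constant_mask [2, 3] : vector<4x8xi1>
/// becomes a vector.constant_mask [3] : vector<8xi1> inserted at positions
/// 0 and 1 of an all-false vector<4x8xi1>.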
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.mask_dim_sizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
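///
/// For instance (an illustrative case, not taken from a test):
///   %x = vector.create_mask %a, %b : vector<4x8xi1>
/// is rewritten so that each of the 4 rows is either
/// vector.create_mask %b : vector<8xi1> (when the row index is < %a) or an
/// all-false vector<8xi1>, selected with arith.cmpi + select.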
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.source(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle.
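///
/// For instance (an illustrative case, not taken from a test):
///   %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
/// becomes two vector.extract_strided_slice ops producing the vector<4xf32>
/// halves of %0, each inserted as one row of the vector<2x4xf32> result.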
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.source(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
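    // srcIdx and resIdx act as mixed-radix counters over the source and result
    // shapes; incIdx advances the most minor position and carries over into
    // higher dimensions when a dimension size is reached.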
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d2)>],
///          iterator_types = ["parallel", "reduction", "parallel"],
///          kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.kind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.source().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel", "reduction"],
///          kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel", "reduction"],
///          kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.transp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.vector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel", "reduction"],
///          kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///          iterator_types = ["parallel", "parallel", "reduction"],
///          kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.lhs();
    Value rhs = contractOp.rhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;
      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.source();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.acc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
    return success();
  }
};

} // namespace

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.lhs();
  auto lhsMap = op.indexing_maps()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.rhs();
  auto rhsMap = op.indexing_maps()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
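  // Flatten both operands to 1-D vectors, emit a vector.matmul on the
  // flattened data, then shape_cast the result back to 2-D and combine it
  // with the accumulator (transposing first if the accumulator is
  // column-major).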
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.acc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.indexing_maps()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.acc(), mul))
          : static_cast<Value>(rew.create<arith::AddFOp>(loc, op.acc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp> {

  UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp>(builder, op),
        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
        lhsType(op.getLhsType()) {}

  Value t(Value v) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    return builder.create<vector::TransposeOp>(loc, v, perm);
  }

  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
    assert(reductionSize > 0);
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value a = builder.create<vector::ExtractOp>(loc, lhs, k);
      Value b = builder.create<vector::ExtractOp>(loc, rhs, k);
      res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b,
                                                   res, kind);
    }
    return res;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(builder.getContext(), m, n, k);
    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      Value tlhs = t(lhs);
      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}}))
      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}}))
      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      Value trhs = t(rhs);
      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
    }
    if (layout({{k, m}, {k, n}, {n, m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    if (layout({{k, m}, {n, k}, {n, m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  /// One outer parallel, one inner reduction (matvec flavor)
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(builder.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor)
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(builder.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res;
  VectorType lhsType;
};
} // namespace

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///    %at = vector.transpose %a, [1, 0]
///    %bRow0 = vector.extract %b[0]
///    %atRow0 = vector.extract %at[0]
///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
///    ...
///    %bRowK = vector.extract %b[K]
///    %atRowK = vector.extract %at[K]
///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
    vector::ContractionOp op, PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    rewriter.replaceOp(op, *matmatRes);
    return success();
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    rewriter.replaceOp(op, *matvecRes);
    return success();
  }
  FailureOr<Value> tmatvecRes = e.tmatvec();
  if (succeeded(tmatvecRes)) {
    rewriter.replaceOp(op, *tmatvecRes);
    return success();
  }

  return failure();
}

LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                            PatternRewriter &rewriter) const {
  // TODO: implement masks
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.iterator_types().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.lhs(), rhs = op.rhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // This is the classical row-major matmul. Just permute the lhs.
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor)
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = op.getResultType().cast<VectorType>();
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing, so we must unroll explicitly.
  Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
                                                 rewriter.getZeroAttr(dstType));
  bool isInt = dstType.getElementType().isa<IntegerType>();
  for (unsigned r = 0; r < dstRows; ++r) {
    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      Value b = rank == 1
                    ? rhs
                    : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
      Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
      Value reduced = rewriter.create<vector::ReductionOp>(
          op.getLoc(), vector::CombiningKind::ADD, m);

      SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
                                              : SmallVector<int64_t, 2>{r, c};
      res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
    }
  }
  if (auto acc = op.acc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  rewriter.replaceOp(op, res);
  return success();
}

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to DOT or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                       PatternRewriter &rewriter) const {
  // TODO: implement masks.
  if (llvm::size(op.masks()) != 0)
    return failure();

  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();
  ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
    return success();
  ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
    return success();

  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
    return success();
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
      return success();
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      rewriter.replaceOp(
          op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
      return success();
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    rewriter.replaceOp(op, lowerReduction(op, rewriter));
    return success();
  }

  return failure();
}

// Lower one parallel dimension.
// TODO: consider reusing existing contract unrolling
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
                                           int64_t lhsIndex, int64_t rhsIndex,
                                           PatternRewriter &rewriter) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = op.getResultType().cast<VectorType>();
  // Find the iterator type index and result index.
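  // The parallel dimension being lowered may come from the LHS (free or batch)
  // or from the RHS only (free); use whichever operand carries it to determine
  // the iteration-space index and the size to unroll.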
1503 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1504 int64_t iterIndex = -1; 1505 int64_t dimSize = -1; 1506 if (lhsIndex >= 0) { 1507 iterIndex = iMap[0].getDimPosition(lhsIndex); 1508 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1509 "parallel index should be free in LHS or batch in LHS/RHS"); 1510 dimSize = lhsType.getDimSize(lhsIndex); 1511 } else { 1512 assert(rhsIndex >= 0 && "missing parallel index"); 1513 iterIndex = iMap[1].getDimPosition(rhsIndex); 1514 dimSize = rhsType.getDimSize(rhsIndex); 1515 } 1516 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1517 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1518 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1519 int64_t resIndex = lookup.getValue(); 1520 // Construct new iterator types and affine map array attribute. 1521 std::array<AffineMap, 3> lowIndexingMaps = { 1522 adjustMap(iMap[0], iterIndex, rewriter), 1523 adjustMap(iMap[1], iterIndex, rewriter), 1524 adjustMap(iMap[2], iterIndex, rewriter)}; 1525 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1526 auto lowIter = 1527 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1528 // Unroll into a series of lower dimensional vector.contract ops. 1529 Location loc = op.getLoc(); 1530 Value result = rewriter.create<arith::ConstantOp>( 1531 loc, resType, rewriter.getZeroAttr(resType)); 1532 for (int64_t d = 0; d < dimSize; ++d) { 1533 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1534 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1535 auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); 1536 Value lowContract = rewriter.create<vector::ContractionOp>( 1537 loc, lhs, rhs, acc, lowAffine, lowIter); 1538 result = 1539 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1540 } 1541 return result; 1542 } 1543 1544 // Lower one reduction dimension. 1545 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1546 PatternRewriter &rewriter) const { 1547 auto loc = op.getLoc(); 1548 VectorType lhsType = op.getLhsType(); 1549 VectorType rhsType = op.getRhsType(); 1550 Type resType = op.getResultType(); 1551 assert(!resType.isa<VectorType>()); 1552 bool isInt = resType.isa<IntegerType>(); 1553 // Use iterator index 0. 1554 int64_t iterIndex = 0; 1555 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1556 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1557 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1558 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1559 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1560 int64_t lhsIndex = lookupLhs.getValue(); 1561 int64_t rhsIndex = lookupRhs.getValue(); 1562 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1563 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1564 // Base case. 1565 if (lhsType.getRank() == 1) { 1566 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1567 Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); 1568 auto kind = vector::CombiningKind::ADD; 1569 Value res = rewriter.create<vector::ReductionOp>(loc, kind, m); 1570 if (auto acc = op.acc()) 1571 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1572 return res; 1573 } 1574 // Construct new iterator types and affine map array attribute. 
1575 std::array<AffineMap, 3> lowIndexingMaps = { 1576 adjustMap(iMap[0], iterIndex, rewriter), 1577 adjustMap(iMap[1], iterIndex, rewriter), 1578 adjustMap(iMap[2], iterIndex, rewriter)}; 1579 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1580 auto lowIter = 1581 rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); 1582 // Unroll into a series of lower dimensional vector.contract ops. 1583 // By feeding the initial accumulator into the first contraction, 1584 // and the result of each contraction into the next, eventually 1585 // the sum of all reductions is computed. 1586 Value result = op.acc(); 1587 for (int64_t d = 0; d < dimSize; ++d) { 1588 auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); 1589 auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); 1590 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1591 lowAffine, lowIter); 1592 } 1593 return result; 1594 } 1595 1596 } // namespace mlir 1597 1598 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1599 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1600 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1601 OpBuilder::InsertionGuard guard(builder); 1602 builder.setInsertionPointAfter(op); 1603 Location loc = op->getLoc(); 1604 if (op->getNumResults() != 1) 1605 return {}; 1606 Value result = op->getResult(0); 1607 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1608 if (!type || map.getNumResults() != multiplicity.size()) 1609 return {}; 1610 // For each dimension being distributed check that the size is a multiple of 1611 // the multiplicity. To handle more sizes we would need to support masking. 1612 unsigned multiplictyCount = 0; 1613 for (auto exp : map.getResults()) { 1614 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1615 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1616 type.getDimSize(affinExp.getPosition()) % 1617 multiplicity[multiplictyCount++] != 1618 0) 1619 return {}; 1620 } 1621 DistributeOps ops; 1622 ops.extract = 1623 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1624 ops.insert = 1625 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1626 return ops; 1627 } 1628 1629 /// Progressive lowering of transfer_read. This pattern supports lowering of 1630 /// `vector.transfer_read` to a combination of `vector.load` and 1631 /// `vector.broadcast` if all of the following hold: 1632 /// - Stride of most minor memref dimension must be 1. 1633 /// - Out-of-bounds masking is not required. 1634 /// - If the memref's element type is a vector type then it coincides with the 1635 /// result type. 1636 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 1637 struct TransferReadToVectorLoadLowering 1638 : public OpRewritePattern<vector::TransferReadOp> { 1639 TransferReadToVectorLoadLowering(MLIRContext *context, 1640 llvm::Optional<unsigned> maxRank) 1641 : OpRewritePattern<vector::TransferReadOp>(context), 1642 maxTransferRank(maxRank) {} 1643 1644 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1645 PatternRewriter &rewriter) const override { 1646 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 1647 return failure(); 1648 1649 SmallVector<unsigned, 4> broadcastedDims; 1650 // Permutations are handled by VectorToSCF or 1651 // populateVectorTransferPermutationMapLoweringPatterns. 1652 // We let the 0-d corner case pass-through as it is supported. 
1653 if (!read.permutation_map().isMinorIdentityWithBroadcasting( 1654 &broadcastedDims)) 1655 return failure(); 1656 1657 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 1658 if (!memRefType) 1659 return failure(); 1660 1661 // Non-unit strides are handled by VectorToSCF. 1662 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1663 return failure(); 1664 1665 // If there is broadcasting involved then we first load the unbroadcasted 1666 // vector, and then broadcast it with `vector.broadcast`. 1667 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 1668 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 1669 vectorShape.end()); 1670 for (unsigned i : broadcastedDims) 1671 unbroadcastedVectorShape[i] = 1; 1672 VectorType unbroadcastedVectorType = VectorType::get( 1673 unbroadcastedVectorShape, read.getVectorType().getElementType()); 1674 1675 // `vector.load` supports vector types as memref's elements only when the 1676 // resulting vector type is the same as the element type. 1677 auto memrefElTy = memRefType.getElementType(); 1678 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 1679 return failure(); 1680 1681 // Otherwise, element types of the memref and the vector must match. 1682 if (!memrefElTy.isa<VectorType>() && 1683 memrefElTy != read.getVectorType().getElementType()) 1684 return failure(); 1685 1686 // Out-of-bounds dims are handled by MaterializeTransferMask. 1687 if (read.hasOutOfBoundsDim()) 1688 return failure(); 1689 1690 // Create vector load op. 1691 Operation *loadOp; 1692 if (read.mask()) { 1693 Value fill = rewriter.create<vector::SplatOp>( 1694 read.getLoc(), unbroadcastedVectorType, read.padding()); 1695 loadOp = rewriter.create<vector::MaskedLoadOp>( 1696 read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), 1697 read.mask(), fill); 1698 } else { 1699 loadOp = rewriter.create<vector::LoadOp>(read.getLoc(), 1700 unbroadcastedVectorType, 1701 read.source(), read.indices()); 1702 } 1703 1704 // Insert a broadcasting op if required. 1705 if (!broadcastedDims.empty()) { 1706 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 1707 read, read.getVectorType(), loadOp->getResult(0)); 1708 } else { 1709 rewriter.replaceOp(read, loadOp->getResult(0)); 1710 } 1711 1712 return success(); 1713 } 1714 1715 llvm::Optional<unsigned> maxTransferRank; 1716 }; 1717 1718 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 1719 // TODO: we shouldn't cross the vector/scalar domains just for this 1720 // but atm we lack the infra to avoid it. Possible solutions include: 1721 // - go directly to LLVM + bitcast 1722 // - introduce a bitcast op and likely a new pointer dialect 1723 // - let memref.load/store additionally support the 0-d vector case 1724 // There are still deeper data layout issues lingering even in this 1725 // trivial case (for architectures for which this matters). 
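//
// As a sketch of the intended rewrite for the single-element case (example IR
// only, not taken from a test):
//   %0 = vector.load %mem[%i] : memref<8xf32>, vector<1xf32>
// becomes:
//   %s = memref.load %mem[%i] : memref<8xf32>
//   %0 = vector.broadcast %s : f32 to vector<1xf32>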
1726 struct VectorLoadToMemrefLoadLowering 1727 : public OpRewritePattern<vector::LoadOp> { 1728 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 1729 1730 LogicalResult matchAndRewrite(vector::LoadOp loadOp, 1731 PatternRewriter &rewriter) const override { 1732 auto vecType = loadOp.getVectorType(); 1733 if (vecType.getNumElements() != 1) 1734 return failure(); 1735 auto memrefLoad = rewriter.create<memref::LoadOp>( 1736 loadOp.getLoc(), loadOp.base(), loadOp.indices()); 1737 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, 1738 memrefLoad); 1739 return success(); 1740 } 1741 }; 1742 1743 /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 1744 struct VectorStoreToMemrefStoreLowering 1745 : public OpRewritePattern<vector::StoreOp> { 1746 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 1747 1748 LogicalResult matchAndRewrite(vector::StoreOp storeOp, 1749 PatternRewriter &rewriter) const override { 1750 auto vecType = storeOp.getVectorType(); 1751 if (vecType.getNumElements() != 1) 1752 return failure(); 1753 Value extracted; 1754 if (vecType.getRank() == 0) { 1755 // TODO: Unify once ExtractOp supports 0-d vectors. 1756 extracted = rewriter.create<vector::ExtractElementOp>( 1757 storeOp.getLoc(), storeOp.valueToStore()); 1758 } else { 1759 SmallVector<int64_t> indices(vecType.getRank(), 0); 1760 extracted = rewriter.create<vector::ExtractOp>( 1761 storeOp.getLoc(), storeOp.valueToStore(), indices); 1762 } 1763 1764 rewriter.replaceOpWithNewOp<memref::StoreOp>( 1765 storeOp, extracted, storeOp.base(), storeOp.indices()); 1766 return success(); 1767 } 1768 }; 1769 1770 /// Progressive lowering of transfer_write. This pattern supports lowering of 1771 /// `vector.transfer_write` to `vector.store` if all of the following hold: 1772 /// - Stride of most minor memref dimension must be 1. 1773 /// - Out-of-bounds masking is not required. 1774 /// - If the memref's element type is a vector type then it coincides with the 1775 /// type of the written value. 1776 /// - The permutation map is the minor identity map (neither permutation nor 1777 /// broadcasting is allowed). 1778 struct TransferWriteToVectorStoreLowering 1779 : public OpRewritePattern<vector::TransferWriteOp> { 1780 TransferWriteToVectorStoreLowering(MLIRContext *context, 1781 llvm::Optional<unsigned> maxRank) 1782 : OpRewritePattern<vector::TransferWriteOp>(context), 1783 maxTransferRank(maxRank) {} 1784 1785 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 1786 PatternRewriter &rewriter) const override { 1787 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) 1788 return failure(); 1789 1790 // Permutations are handled by VectorToSCF or 1791 // populateVectorTransferPermutationMapLoweringPatterns. 1792 if ( // pass-through for the 0-d corner case. 1793 !write.permutation_map().isMinorIdentity()) 1794 return failure(); 1795 1796 auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); 1797 if (!memRefType) 1798 return failure(); 1799 1800 // Non-unit strides are handled by VectorToSCF. 1801 if (!vector::isLastMemrefDimUnitStride(memRefType)) 1802 return failure(); 1803 1804 // `vector.store` supports vector types as memref's elements only when the 1805 // type of the vector value being written is the same as the element type.
1806 auto memrefElTy = memRefType.getElementType(); 1807 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) 1808 return failure(); 1809 1810 // Otherwise, element types of the memref and the vector must match. 1811 if (!memrefElTy.isa<VectorType>() && 1812 memrefElTy != write.getVectorType().getElementType()) 1813 return failure(); 1814 1815 // Out-of-bounds dims are handled by MaterializeTransferMask. 1816 if (write.hasOutOfBoundsDim()) 1817 return failure(); 1818 if (write.mask()) { 1819 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( 1820 write, write.source(), write.indices(), write.mask(), write.vector()); 1821 } else { 1822 rewriter.replaceOpWithNewOp<vector::StoreOp>( 1823 write, write.vector(), write.source(), write.indices()); 1824 } 1825 return success(); 1826 } 1827 1828 llvm::Optional<unsigned> maxTransferRank; 1829 }; 1830 1831 // Returns the values in `arrayAttr` as an integer vector. 1832 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) { 1833 return llvm::to_vector<4>( 1834 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 1835 [](IntegerAttr attr) { return attr.getInt(); })); 1836 } 1837 1838 // Shuffles vector.bitcast op after vector.extract op. 1839 // 1840 // This transforms IR like: 1841 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> 1842 // %1 = vector.extract %0[3] : vector<8xf16> 1843 // Into: 1844 // %0 = vector.extract %src[1] : vector<4xf32> 1845 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> 1846 // %2 = vector.extract %1[1] : vector<2xf16> 1847 struct BubbleDownVectorBitCastForExtract 1848 : public OpRewritePattern<vector::ExtractOp> { 1849 using OpRewritePattern::OpRewritePattern; 1850 1851 LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 1852 PatternRewriter &rewriter) const override { 1853 // Only support extracting scalars for now. 1854 if (extractOp.getVectorType().getRank() != 1) 1855 return failure(); 1856 1857 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 1858 if (!castOp) 1859 return failure(); 1860 1861 VectorType castSrcType = castOp.getSourceVectorType(); 1862 VectorType castDstType = castOp.getResultVectorType(); 1863 assert(castSrcType.getRank() == castDstType.getRank()); 1864 1865 // Fail to match if we only have one element in the cast op source. 1866 // This is to avoid an infinite loop given that this pattern can generate 1867 // such cases. 1868 if (castSrcType.getNumElements() == 1) 1869 return failure(); 1870 1871 // Only support casting to a larger number of elements for now. 1872 // E.g., vector<4xf32> -> vector<8xf16>. 1873 if (castSrcType.getNumElements() > castDstType.getNumElements()) 1874 return failure(); 1875 1876 unsigned expandRatio = 1877 castDstType.getNumElements() / castSrcType.getNumElements(); 1878 1879 auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { 1880 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 1881 }; 1882 1883 uint64_t index = getFirstIntValue(extractOp.position()); 1884 1885 // Get the single scalar (as a vector) in the source value that packs the 1886 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> 1887 VectorType oneScalarType = 1888 VectorType::get({1}, castSrcType.getElementType()); 1889 Value packedValue = rewriter.create<vector::ExtractOp>( 1890 extractOp.getLoc(), oneScalarType, castOp.source(), 1891 rewriter.getI64ArrayAttr(index / expandRatio)); 1892 1893 // Cast it to a vector with the desired scalar's type. 1894 // E.g.
f32 -> vector<2xf16> 1895 VectorType packedType = 1896 VectorType::get({expandRatio}, castDstType.getElementType()); 1897 Value castedValue = rewriter.create<vector::BitCastOp>( 1898 extractOp.getLoc(), packedType, packedValue); 1899 1900 // Finally extract the desired scalar. 1901 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 1902 extractOp, extractOp.getType(), castedValue, 1903 rewriter.getI64ArrayAttr(index % expandRatio)); 1904 1905 return success(); 1906 } 1907 }; 1908 1909 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 1910 // 1911 // This transforms IR like: 1912 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 1913 // %0 = vector.extract_strided_slice %cast { 1914 // offsets = [4], sizes = [4], strides = [1] 1915 // } : vector<8xf16> to vector<4xf16> 1916 // Into: 1917 // %0 = vector.extract_strided_slice %arg0 { 1918 // offsets = [2], sizes = [2], strides = [1] 1919 // } : vector<4xf32> to vector<2xf32> 1920 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 1921 struct BubbleDownBitCastForStridedSliceExtract 1922 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 1923 using OpRewritePattern::OpRewritePattern; 1924 1925 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 1926 PatternRewriter &rewriter) const override { 1927 auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>(); 1928 if (!castOp) 1929 return failure(); 1930 1931 VectorType castSrcType = castOp.getSourceVectorType(); 1932 VectorType castDstType = castOp.getResultVectorType(); 1933 assert(castSrcType.getRank() == castDstType.getRank()); 1934 1935 int64_t castSrcLastDim = castSrcType.getShape().back(); 1936 int64_t castDstLastDim = castDstType.getShape().back(); 1937 // Require casting to more elements for now; other cases to be implemented. 1938 if (castSrcLastDim > castDstLastDim) 1939 return failure(); 1940 1941 // Only accept all-one strides for now. 1942 if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(), 1943 [](const APInt &val) { return !val.isOneValue(); })) 1944 return failure(); 1945 1946 unsigned rank = extractOp.getVectorType().getRank(); 1947 assert(castDstLastDim % castSrcLastDim == 0); 1948 int64_t expandRatio = castDstLastDim / castSrcLastDim; 1949 1950 // If we have fewer offsets than the rank, then we implicitly 1951 // select the full range for the last bitcasted dimension; other 1952 // dimensions aren't affected. Otherwise, we need to scale down the last 1953 // dimension's offset since we are extracting from fewer elements now. 1954 ArrayAttr newOffsets = extractOp.offsets(); 1955 if (newOffsets.size() == rank) { 1956 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 1957 if (offsets.back() % expandRatio != 0) 1958 return failure(); 1959 offsets.back() = offsets.back() / expandRatio; 1960 newOffsets = rewriter.getI64ArrayAttr(offsets); 1961 } 1962 1963 // Similarly for sizes.
1964 ArrayAttr newSizes = extractOp.sizes(); 1965 if (newSizes.size() == rank) { 1966 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 1967 if (sizes.back() % expandRatio != 0) 1968 return failure(); 1969 sizes.back() = sizes.back() / expandRatio; 1970 newSizes = rewriter.getI64ArrayAttr(sizes); 1971 } 1972 1973 SmallVector<int64_t, 4> dims = 1974 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 1975 dims.back() = dims.back() / expandRatio; 1976 VectorType newExtractType = 1977 VectorType::get(dims, castSrcType.getElementType()); 1978 1979 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 1980 extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, 1981 newSizes, extractOp.strides()); 1982 1983 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 1984 extractOp, extractOp.getType(), newExtractOp); 1985 1986 return success(); 1987 } 1988 }; 1989 1990 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 1991 // 1992 // This transforms IR like: 1993 // %0 = vector.insert_strided_slice %src, %dst { 1994 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 1995 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 1996 // Into: 1997 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 1998 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 1999 // %2 = vector.insert_strided_slice %0, %1 { 2000 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2001 struct BubbleUpBitCastForStridedSliceInsert 2002 : public OpRewritePattern<vector::BitCastOp> { 2003 using OpRewritePattern::OpRewritePattern; 2004 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2005 PatternRewriter &rewriter) const override { 2006 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2007 VectorType castDstType = bitcastOp.getResultVectorType(); 2008 assert(castSrcType.getRank() == castDstType.getRank()); 2009 2010 int64_t castSrcLastDim = castSrcType.getShape().back(); 2011 int64_t castDstLastDim = castDstType.getShape().back(); 2012 // Require casting to fewer elements for now; other cases to be implemented. 2013 if (castSrcLastDim < castDstLastDim) 2014 return failure(); 2015 2016 assert(castSrcLastDim % castDstLastDim == 0); 2017 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2018 2019 auto insertOp = 2020 bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>(); 2021 if (!insertOp) 2022 return failure(); 2023 2024 // Only accept all-one strides for now. 2025 if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(), 2026 [](const APInt &val) { return !val.isOneValue(); })) 2027 return failure(); 2028 2029 unsigned rank = insertOp.getSourceVectorType().getRank(); 2030 // Require insert op to have the same rank for the source and destination 2031 // vector; other cases to be implemented.
2032 if (rank != insertOp.getDestVectorType().getRank()) 2033 return failure(); 2034 2035 ArrayAttr newOffsets = insertOp.offsets(); 2036 assert(newOffsets.size() == rank); 2037 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2038 if (offsets.back() % shrinkRatio != 0) 2039 return failure(); 2040 offsets.back() = offsets.back() / shrinkRatio; 2041 newOffsets = rewriter.getI64ArrayAttr(offsets); 2042 2043 SmallVector<int64_t, 4> srcDims = 2044 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2045 srcDims.back() = srcDims.back() / shrinkRatio; 2046 VectorType newCastSrcType = 2047 VectorType::get(srcDims, castDstType.getElementType()); 2048 2049 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2050 bitcastOp.getLoc(), newCastSrcType, insertOp.source()); 2051 2052 SmallVector<int64_t, 4> dstDims = 2053 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2054 dstDims.back() = dstDims.back() / shrinkRatio; 2055 VectorType newCastDstType = 2056 VectorType::get(dstDims, castDstType.getElementType()); 2057 2058 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2059 bitcastOp.getLoc(), newCastDstType, insertOp.dest()); 2060 2061 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2062 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2063 insertOp.strides()); 2064 2065 return success(); 2066 } 2067 }; 2068 2069 static Value createCastToIndexLike(PatternRewriter &rewriter, Location loc, 2070 Type targetType, Value value) { 2071 if (targetType == value.getType()) 2072 return value; 2073 2074 bool targetIsIndex = targetType.isIndex(); 2075 bool valueIsIndex = value.getType().isIndex(); 2076 if (targetIsIndex ^ valueIsIndex) 2077 return rewriter.create<arith::IndexCastOp>(loc, targetType, value); 2078 2079 auto targetIntegerType = targetType.dyn_cast<IntegerType>(); 2080 auto valueIntegerType = value.getType().dyn_cast<IntegerType>(); 2081 assert(targetIntegerType && valueIntegerType && 2082 "unexpected cast between types other than integers and index"); 2083 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); 2084 2085 if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) 2086 return rewriter.create<arith::ExtSIOp>(loc, targetIntegerType, value); 2087 return rewriter.create<arith::TruncIOp>(loc, targetIntegerType, value); 2088 } 2089 2090 // Helper that returns a vector comparison that constructs a mask: 2091 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2092 // 2093 // If `dim == 0` then the result will be a 0-D vector. 2094 // 2095 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2096 // much more compact, IR for this operation, but LLVM eventually 2097 // generates more elaborate instructions for this intrinsic since it 2098 // is very conservative on the boundary conditions. 2099 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2100 bool indexOptimizations, int64_t dim, 2101 Value b, Value *off = nullptr) { 2102 auto loc = op->getLoc(); 2103 // If we can assume all indices fit in 32-bit, we perform the vector 2104 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2105 // Otherwise we perform the vector comparison using 64-bit indices. 2106 Type idxType = 2107 indexOptimizations ? 
rewriter.getI32Type() : rewriter.getI64Type(); 2108 DenseIntElementsAttr indicesAttr; 2109 if (dim == 0 && indexOptimizations) { 2110 indicesAttr = DenseIntElementsAttr::get( 2111 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2112 } else if (dim == 0) { 2113 indicesAttr = DenseIntElementsAttr::get( 2114 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2115 } else if (indexOptimizations) { 2116 indicesAttr = rewriter.getI32VectorAttr( 2117 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2118 } else { 2119 indicesAttr = rewriter.getI64VectorAttr( 2120 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2121 } 2122 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2123 // Add in an offset if requested. 2124 if (off) { 2125 Value o = createCastToIndexLike(rewriter, loc, idxType, *off); 2126 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2127 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2128 } 2129 // Construct the vector comparison. 2130 Value bound = createCastToIndexLike(rewriter, loc, idxType, b); 2131 Value bounds = 2132 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2133 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2134 bounds); 2135 } 2136 2137 template <typename ConcreteOp> 2138 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2139 public: 2140 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2141 : mlir::OpRewritePattern<ConcreteOp>(context), 2142 indexOptimizations(enableIndexOpt) {} 2143 2144 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2145 PatternRewriter &rewriter) const override { 2146 if (!xferOp.hasOutOfBoundsDim()) 2147 return failure(); 2148 2149 if (xferOp.getVectorType().getRank() > 1 || 2150 llvm::size(xferOp.indices()) == 0) 2151 return failure(); 2152 2153 Location loc = xferOp->getLoc(); 2154 VectorType vtp = xferOp.getVectorType(); 2155 2156 // * Create a vector with linear indices [ 0 .. vector_length - 1 ]. 2157 // * Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. 2158 // * Let dim be the memref dimension; compute the vector comparison mask 2159 // (in-bounds mask): 2160 // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] 2161 // 2162 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2163 // dimensions here. 2164 unsigned vecWidth = vtp.getNumElements(); 2165 unsigned lastIndex = llvm::size(xferOp.indices()) - 1; 2166 Value off = xferOp.indices()[lastIndex]; 2167 Value dim = 2168 vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); 2169 Value mask = buildVectorComparison(rewriter, xferOp, indexOptimizations, 2170 vecWidth, dim, &off); 2171 2172 if (xferOp.mask()) { 2173 // Intersect the in-bounds mask with the mask specified as an op parameter. 2174 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.mask()); 2175 } 2176 2177 rewriter.updateRootInPlace(xferOp, [&]() { 2178 xferOp.maskMutable().assign(mask); 2179 xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); 2180 }); 2181 2182 return success(); 2183 } 2184 2185 private: 2186 const bool indexOptimizations; 2187 }; 2188 2189 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
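///
/// Sketch of the 1-D rewrite with index optimizations enabled (illustrative
/// IR, not from a test; %ub_i32 stands for the index-cast of the operand):
///   %m = vector.create_mask %ub : vector<4xi1>
/// becomes roughly:
///   %steps = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
///   %bound = vector.splat %ub_i32 : vector<4xi32>
///   %m     = arith.cmpi slt, %steps, %bound : vector<4xi32>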
2190 class VectorCreateMaskOpConversion 2191 : public OpRewritePattern<vector::CreateMaskOp> { 2192 public: 2193 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2194 bool enableIndexOpt) 2195 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2196 indexOptimizations(enableIndexOpt) {} 2197 2198 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2199 PatternRewriter &rewriter) const override { 2200 auto dstType = op.getType(); 2201 int64_t rank = dstType.getRank(); 2202 if (rank > 1) 2203 return failure(); 2204 rewriter.replaceOp( 2205 op, buildVectorComparison(rewriter, op, indexOptimizations, 2206 rank == 0 ? 0 : dstType.getDimSize(0), 2207 op.getOperand(0))); 2208 return success(); 2209 } 2210 2211 private: 2212 const bool indexOptimizations; 2213 }; 2214 2215 // Drop innermost contiguous unit dimensions from the transfer_read operand. 2216 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2217 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2218 2219 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2220 PatternRewriter &rewriter) const override { 2221 // TODO: support 0-d corner case. 2222 if (readOp.getTransferRank() == 0) 2223 return failure(); 2224 2225 // TODO: support mask. 2226 if (readOp.mask()) 2227 return failure(); 2228 2229 auto srcType = readOp.source().getType().dyn_cast<MemRefType>(); 2230 if (!srcType || !srcType.hasStaticShape()) 2231 return failure(); 2232 2233 if (!readOp.permutation_map().isMinorIdentity()) 2234 return failure(); 2235 2236 auto targetType = readOp.getVectorType(); 2237 if (targetType.getRank() <= 1) 2238 return failure(); 2239 2240 SmallVector<int64_t> srcStrides; 2241 int64_t srcOffset; 2242 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2243 return failure(); 2244 2245 size_t dimsToDrop = 0; 2246 for (size_t i = 1; i < srcStrides.size(); ++i) { 2247 int dim = srcType.getRank() - i - 1; 2248 if (srcStrides[dim] == 1) { 2249 dimsToDrop++; 2250 } else { 2251 break; 2252 } 2253 } 2254 if (dimsToDrop == 0) 2255 return failure(); 2256 2257 auto resultTargetVecType = 2258 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2259 targetType.getElementType()); 2260 2261 MemRefType resultMemrefType; 2262 if (srcType.getLayout().getAffineMap().isIdentity()) { 2263 resultMemrefType = MemRefType::get( 2264 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2265 {}, srcType.getMemorySpaceAsInt()); 2266 } else { 2267 AffineMap map = srcType.getLayout().getAffineMap(); 2268 int numResultDims = map.getNumDims() - dimsToDrop; 2269 int numSymbols = map.getNumSymbols(); 2270 for (size_t i = 0; i < dimsToDrop; ++i) { 2271 int dim = srcType.getRank() - i - 1; 2272 map = map.replace(rewriter.getAffineDimExpr(dim), 2273 rewriter.getAffineConstantExpr(0), numResultDims, 2274 numSymbols); 2275 } 2276 resultMemrefType = MemRefType::get( 2277 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2278 map, srcType.getMemorySpaceAsInt()); 2279 } 2280 2281 auto loc = readOp.getLoc(); 2282 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2283 SmallVector<int64_t> strides(srcType.getRank(), 1); 2284 2285 ArrayAttr inBoundsAttr = 2286 readOp.in_bounds() 2287 ?
rewriter.getArrayAttr( 2288 readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) 2289 : ArrayAttr(); 2290 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2291 loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), 2292 strides); 2293 auto permMap = getTransferMinorIdentityMap( 2294 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2295 Value result = rewriter.create<vector::TransferReadOp>( 2296 loc, resultTargetVecType, rankedReducedView, 2297 readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2298 readOp.padding(), 2299 // TODO: support mask. 2300 /*mask=*/Value(), inBoundsAttr); 2301 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2302 result); 2303 return success(); 2304 } 2305 }; 2306 2307 namespace { 2308 2309 /// This function checks to see if the vector combining kind 2310 /// is consistent with the integer or float element type. 2311 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2312 using vector::CombiningKind; 2313 enum class KindType { FLOAT, INT, INVALID }; 2314 KindType type{KindType::INVALID}; 2315 switch (kind) { 2316 case CombiningKind::MINF: 2317 case CombiningKind::MAXF: 2318 type = KindType::FLOAT; 2319 break; 2320 case CombiningKind::MINUI: 2321 case CombiningKind::MINSI: 2322 case CombiningKind::MAXUI: 2323 case CombiningKind::MAXSI: 2324 case CombiningKind::AND: 2325 case CombiningKind::OR: 2326 case CombiningKind::XOR: 2327 type = KindType::INT; 2328 break; 2329 case CombiningKind::ADD: 2330 case CombiningKind::MUL: 2331 type = isInt ? KindType::INT : KindType::FLOAT; 2332 break; 2333 } 2334 bool isValidIntKind = (type == KindType::INT) && isInt; 2335 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2336 return (isValidIntKind || isValidFloatKind); 2337 } 2338 2339 /// This function constructs the appropriate integer or float 2340 /// operation given the vector combining kind and operands. The 2341 /// supported int operations are : add, mul, min (signed/unsigned), 2342 /// max(signed/unsigned), and, or, xor. The supported float 2343 /// operations are : add, mul, min and max. 
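///
/// For example (a sketch, assuming vector<4xf32> operands %x and %y), the
/// MAXF kind materializes:
///   %0 = arith.maxf %x, %y : vector<4xf32>
/// while the ADD kind on integer operands materializes arith.addi instead.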
2344 static Value genOperator(Location loc, Value x, Value y, 2345 vector::CombiningKind kind, 2346 PatternRewriter &rewriter) { 2347 using vector::CombiningKind; 2348 2349 auto elType = x.getType().cast<VectorType>().getElementType(); 2350 bool isInt = elType.isIntOrIndex(); 2351 2352 Value combinedResult{nullptr}; 2353 switch (kind) { 2354 case CombiningKind::ADD: 2355 if (isInt) 2356 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2357 else 2358 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2359 break; 2360 case CombiningKind::MUL: 2361 if (isInt) 2362 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2363 else 2364 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2365 break; 2366 case CombiningKind::MINUI: 2367 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2368 break; 2369 case CombiningKind::MINSI: 2370 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2371 break; 2372 case CombiningKind::MAXUI: 2373 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2374 break; 2375 case CombiningKind::MAXSI: 2376 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2377 break; 2378 case CombiningKind::AND: 2379 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2380 break; 2381 case CombiningKind::OR: 2382 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2383 break; 2384 case CombiningKind::XOR: 2385 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2386 break; 2387 case CombiningKind::MINF: 2388 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2389 break; 2390 case CombiningKind::MAXF: 2391 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2392 break; 2393 } 2394 return combinedResult; 2395 } 2396 2397 /// Convert vector.scan op into arith ops and 2398 /// vector.insert_strided_slice/extract_strided_slice 2399 /// 2400 /// Ex: 2401 /// ``` 2402 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2403 /// 1} : 2404 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2405 /// ``` 2406 /// Gets converted to: 2407 /// ``` 2408 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2409 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2410 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 = 2411 /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} 2412 /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice 2413 /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : 2414 /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : 2415 /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], 2416 /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = 2417 /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], 2418 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, 2419 /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = 2420 /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = 2421 /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : 2422 /// vector<2x3xi32>, vector<2xi32> 2423 /// ``` 2424 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { 2425 using OpRewritePattern<vector::ScanOp>::OpRewritePattern; 2426 2427 LogicalResult matchAndRewrite(vector::ScanOp scanOp, 2428 PatternRewriter &rewriter) const override { 2429 auto loc = scanOp.getLoc(); 2430 VectorType 
destType = scanOp.getDestType(); 2431 ArrayRef<int64_t> destShape = destType.getShape(); 2432 auto elType = destType.getElementType(); 2433 bool isInt = elType.isIntOrIndex(); 2434 if (!isValidKind(isInt, scanOp.kind())) 2435 return failure(); 2436 2437 VectorType resType = VectorType::get(destShape, elType); 2438 Value result = rewriter.create<arith::ConstantOp>( 2439 loc, resType, rewriter.getZeroAttr(resType)); 2440 int64_t reductionDim = scanOp.reduction_dim(); 2441 bool inclusive = scanOp.inclusive(); 2442 int64_t destRank = destType.getRank(); 2443 VectorType initialValueType = scanOp.getInitialValueType(); 2444 int64_t initialValueRank = initialValueType.getRank(); 2445 2446 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2447 reductionShape[reductionDim] = 1; 2448 VectorType reductionType = VectorType::get(reductionShape, elType); 2449 SmallVector<int64_t> offsets(destRank, 0); 2450 SmallVector<int64_t> strides(destRank, 1); 2451 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2452 sizes[reductionDim] = 1; 2453 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2454 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2455 2456 Value lastOutput, lastInput; 2457 for (int i = 0; i < destShape[reductionDim]; i++) { 2458 offsets[reductionDim] = i; 2459 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2460 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2461 loc, reductionType, scanOp.source(), scanOffsets, scanSizes, 2462 scanStrides); 2463 Value output; 2464 if (i == 0) { 2465 if (inclusive) { 2466 output = input; 2467 } else { 2468 if (initialValueRank == 0) { 2469 // ShapeCastOp cannot handle 0-D vectors 2470 output = rewriter.create<vector::BroadcastOp>( 2471 loc, input.getType(), scanOp.initial_value()); 2472 } else { 2473 output = rewriter.create<vector::ShapeCastOp>( 2474 loc, input.getType(), scanOp.initial_value()); 2475 } 2476 } 2477 } else { 2478 Value y = inclusive ? 
input : lastInput; 2479 output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); 2480 assert(output != nullptr); 2481 } 2482 result = rewriter.create<vector::InsertStridedSliceOp>( 2483 loc, output, result, offsets, strides); 2484 lastOutput = output; 2485 lastInput = input; 2486 } 2487 2488 Value reduction; 2489 if (initialValueRank == 0) { 2490 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2491 reduction = 2492 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2493 } else { 2494 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2495 lastOutput); 2496 } 2497 2498 rewriter.replaceOp(scanOp, {result, reduction}); 2499 return success(); 2500 } 2501 }; 2502 2503 } // namespace 2504 2505 void mlir::vector::populateVectorMaskMaterializationPatterns( 2506 RewritePatternSet &patterns, bool indexOptimizations) { 2507 patterns.add<VectorCreateMaskOpConversion, 2508 MaterializeTransferMask<vector::TransferReadOp>, 2509 MaterializeTransferMask<vector::TransferWriteOp>>( 2510 patterns.getContext(), indexOptimizations); 2511 } 2512 2513 void mlir::vector::populateShapeCastFoldingPatterns( 2514 RewritePatternSet &patterns) { 2515 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2516 } 2517 2518 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2519 RewritePatternSet &patterns) { 2520 patterns.add<BubbleDownVectorBitCastForExtract, 2521 BubbleDownBitCastForStridedSliceExtract, 2522 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2523 } 2524 2525 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2526 RewritePatternSet &patterns) { 2527 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2528 } 2529 2530 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2531 RewritePatternSet &patterns) { 2532 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2533 patterns.getContext()); 2534 } 2535 2536 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2537 RewritePatternSet &patterns) { 2538 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2539 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2540 patterns.getContext()); 2541 } 2542 2543 void mlir::vector::populateVectorContractLoweringPatterns( 2544 RewritePatternSet &patterns, VectorTransformsOptions options) { 2545 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2546 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2547 ContractionOpToOuterProductOpLowering>(options, 2548 patterns.getContext()); 2549 } 2550 2551 void mlir::vector::populateVectorTransposeLoweringPatterns( 2552 RewritePatternSet &patterns, VectorTransformsOptions options) { 2553 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2554 options, patterns.getContext()); 2555 } 2556 2557 void mlir::vector::populateVectorReductionToContractPatterns( 2558 RewritePatternSet &patterns) { 2559 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2560 CombineContractTranspose>(patterns.getContext()); 2561 } 2562 2563 void mlir::vector:: 2564 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2565 RewritePatternSet &patterns) { 2566 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2567 } 2568 2569 void mlir::vector::populateVectorTransferLoweringPatterns( 2570 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2571 patterns.add<TransferReadToVectorLoadLowering, 2572 TransferWriteToVectorStoreLowering>(patterns.getContext(), 2573 maxTransferRank); 2574 patterns 
2575 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2576 patterns.getContext()); 2577 } 2578 2579 void mlir::vector::populateVectorScanLoweringPatterns( 2580 RewritePatternSet &patterns) { 2581 patterns.add<ScanToArithOps>(patterns.getContext()); 2582 } 2583
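
// Usage sketch (hypothetical pass body; `funcOp`, `ctx` and `options` are
// assumed to be provided by the surrounding pass, and the greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h is assumed to be available):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorMaskMaterializationPatterns(
//       patterns, /*indexOptimizations=*/true);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));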