1 //===- VectorTransforms.cpp - Conversion within the Vector dialect --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites as 1->N patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 14 15 #include <type_traits> 16 17 #include "mlir/Dialect/Affine/IR/AffineOps.h" 18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 19 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 20 #include "mlir/Dialect/Linalg/IR/Linalg.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/SCF/IR/SCF.h" 23 #include "mlir/Dialect/Utils/IndexingUtils.h" 24 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 25 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 26 #include "mlir/IR/BuiltinTypes.h" 27 #include "mlir/IR/ImplicitLocOpBuilder.h" 28 #include "mlir/IR/Matchers.h" 29 #include "mlir/IR/PatternMatch.h" 30 #include "mlir/Interfaces/VectorInterfaces.h" 31 32 #include "llvm/ADT/DenseSet.h" 33 #include "llvm/ADT/MapVector.h" 34 #include "llvm/ADT/STLExtras.h" 35 #include "llvm/ADT/SmallBitVector.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Debug.h" 38 #include "llvm/Support/raw_ostream.h" 39 40 #define DEBUG_TYPE "vector-to-vector" 41 42 using namespace mlir; 43 using namespace mlir::vector; 44 45 // Helper to find an index in an affine map. 46 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) { 47 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 48 int64_t idx = map.getDimPosition(i); 49 if (idx == index) 50 return i; 51 } 52 return None; 53 } 54 55 // Helper to construct iterator types with one index removed. 56 static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes, 57 int64_t index) { 58 SmallVector<Attribute, 4> results; 59 for (const auto &it : llvm::enumerate(iteratorTypes)) { 60 int64_t idx = it.index(); 61 if (idx == index) 62 continue; 63 results.push_back(it.value()); 64 } 65 return results; 66 } 67 68 // Helper to construct an affine map with one index removed. 69 static AffineMap adjustMap(AffineMap map, int64_t index, 70 PatternRewriter &rewriter) { 71 auto *ctx = rewriter.getContext(); 72 SmallVector<AffineExpr, 4> results; 73 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { 74 int64_t idx = map.getDimPosition(i); 75 if (idx == index) 76 continue; 77 // Re-insert remaining indices, but renamed when occurring 78 // after the removed index. 79 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); 80 results.push_back(targetExpr); 81 } 82 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); 83 } 84 85 // Helper method to possibly drop a dimension in a load. 86 // TODO 87 static Value reshapeLoad(Location loc, Value val, VectorType type, 88 int64_t index, int64_t pos, 89 PatternRewriter &rewriter) { 90 if (index == -1) 91 return val; 92 Type lowType = VectorType::Builder(type).dropDim(0); 93 // At extraction dimension? 94 if (index == 0) { 95 auto posAttr = rewriter.getI64ArrayAttr(pos); 96 return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr); 97 } 98 // Unroll leading dimensions. 
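  // The extraction dimension lies deeper: peel off the leading dimension,
  // recurse into each rank-reduced slice with `index - 1`, and reassemble the
  // slices into a result whose type drops dimension `index`.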
99 VectorType vType = lowType.cast<VectorType>(); 100 Type resType = VectorType::Builder(type).dropDim(index); 101 auto resVectorType = resType.cast<VectorType>(); 102 Value result = rewriter.create<arith::ConstantOp>( 103 loc, resVectorType, rewriter.getZeroAttr(resVectorType)); 104 for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { 105 auto posAttr = rewriter.getI64ArrayAttr(d); 106 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr); 107 Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); 108 result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, 109 posAttr); 110 } 111 return result; 112 } 113 114 // Helper method to possibly drop a dimension in a store. 115 // TODO 116 static Value reshapeStore(Location loc, Value val, Value result, 117 VectorType type, int64_t index, int64_t pos, 118 PatternRewriter &rewriter) { 119 // Unmodified? 120 if (index == -1) 121 return val; 122 // At insertion dimension? 123 if (index == 0) { 124 auto posAttr = rewriter.getI64ArrayAttr(pos); 125 return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr); 126 } 127 // Unroll leading dimensions. 128 Type lowType = VectorType::Builder(type).dropDim(0); 129 VectorType vType = lowType.cast<VectorType>(); 130 Type insType = VectorType::Builder(vType).dropDim(0); 131 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { 132 auto posAttr = rewriter.getI64ArrayAttr(d); 133 Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr); 134 Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr); 135 Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); 136 result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr); 137 } 138 return result; 139 } 140 141 template <typename IntType> 142 static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) { 143 return llvm::to_vector<4>(llvm::map_range( 144 arrayAttr.getAsRange<IntegerAttr>(), 145 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); 146 } 147 148 /// Helper to create arithmetic operation associated with a kind of contraction. 149 static Optional<Value> createContractArithOp(Location loc, Value x, Value y, 150 Value acc, 151 vector::CombiningKind kind, 152 PatternRewriter &rewriter, 153 bool isInt) { 154 using vector::CombiningKind; 155 Value mul; 156 if (isInt) { 157 if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) 158 // Only valid for floating point types. 159 return Optional<Value>(); 160 mul = rewriter.create<arith::MulIOp>(loc, x, y); 161 } else { 162 // Float case. 163 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || 164 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI || 165 kind == CombiningKind::MAXSI || kind == CombiningKind::OR || 166 kind == CombiningKind::XOR) 167 // Only valid for integer types. 168 return Optional<Value>(); 169 // Special case for fused multiply-add. 170 if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) { 171 return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc)); 172 } 173 mul = rewriter.create<arith::MulFOp>(loc, x, y); 174 } 175 if (!acc) 176 return Optional<Value>(mul); 177 return makeArithReduction(rewriter, loc, kind, mul, acc); 178 } 179 180 /// Return the positions of the reductions in the given map. 
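/// E.g., for the map (d0, d1, d2) -> (d0, d2) with iterator types
/// ["parallel", "parallel", "reduction"], the returned positions are {1},
/// since the second result (d2) maps to a reduction iterator.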
181 static SmallVector<int64_t> getReductionIndex(AffineMap map, 182 ArrayAttr iteratorTypes) { 183 SmallVector<int64_t> dimsIdx; 184 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 185 if (isReductionIterator(iteratorTypes[map.getDimPosition(i)])) 186 dimsIdx.push_back(i); 187 } 188 return dimsIdx; 189 } 190 191 /// Look for a given dimension in an affine map and return its position. Return 192 /// llvm::None if the dimension is not in the map results. 193 static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) { 194 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 195 if (map.getDimPosition(i) == dim) 196 return i; 197 } 198 return llvm::None; 199 } 200 201 namespace { 202 203 /// ShapeCastOpFolder folds cancelling ShapeCastOps away. 204 // 205 // Example: 206 // 207 // The following MLIR with cancelling ShapeCastOps: 208 // 209 // %0 = source : vector<5x4x2xf32> 210 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32> 211 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32> 212 // %3 = user %2 : vector<5x4x2xf32> 213 // 214 // Should canonicalize to the following: 215 // 216 // %0 = source : vector<5x4x2xf32> 217 // %1 = user %0 : vector<5x4x2xf32> 218 // 219 struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> { 220 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; 221 222 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, 223 PatternRewriter &rewriter) const override { 224 // Check if 'shapeCastOp' has vector source/result type. 225 auto sourceVectorType = 226 shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>(); 227 auto resultVectorType = 228 shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>(); 229 if (!sourceVectorType || !resultVectorType) 230 return failure(); 231 232 // Check if shape cast op source operand is also a shape cast op. 233 auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>( 234 shapeCastOp.getSource().getDefiningOp()); 235 if (!sourceShapeCastOp) 236 return failure(); 237 auto operandSourceVectorType = 238 sourceShapeCastOp.getSource().getType().cast<VectorType>(); 239 auto operandResultVectorType = sourceShapeCastOp.getType(); 240 241 // Check if shape cast operations invert each other. 242 if (operandSourceVectorType != resultVectorType || 243 operandResultVectorType != sourceVectorType) 244 return failure(); 245 246 rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource()); 247 return success(); 248 } 249 }; 250 251 /// Progressive lowering of BroadcastOp. 252 class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { 253 public: 254 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; 255 256 LogicalResult matchAndRewrite(vector::BroadcastOp op, 257 PatternRewriter &rewriter) const override { 258 auto loc = op.getLoc(); 259 VectorType dstType = op.getVectorType(); 260 VectorType srcType = op.getSourceType().dyn_cast<VectorType>(); 261 Type eltType = dstType.getElementType(); 262 263 // Scalar to any vector can use splat. 264 if (!srcType) { 265 rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource()); 266 return success(); 267 } 268 269 // Determine rank of source and destination. 270 int64_t srcRank = srcType.getRank(); 271 int64_t dstRank = dstType.getRank(); 272 273 // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. 
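    // E.g. broadcasting vector<1xf32> to vector<4xf32> extracts the single
    // source element and splats it over the destination.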
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
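/// E.g., the transpose pattern [1, 0, 2, 3] is pruned to [1, 0], since the two
/// rightmost dimensions are left in place by the permutation.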
362 void pruneNonTransposedDims(ArrayRef<int64_t> transpose, 363 SmallVectorImpl<int64_t> &result) { 364 size_t numTransposedDims = transpose.size(); 365 for (size_t transpDim : llvm::reverse(transpose)) { 366 if (transpDim != numTransposedDims - 1) 367 break; 368 numTransposedDims--; 369 } 370 371 result.append(transpose.begin(), transpose.begin() + numTransposedDims); 372 } 373 374 /// Progressive lowering of TransposeOp. 375 /// One: 376 /// %x = vector.transpose %y, [1, 0] 377 /// is replaced by: 378 /// %z = arith.constant dense<0.000000e+00> 379 /// %0 = vector.extract %y[0, 0] 380 /// %1 = vector.insert %0, %z [0, 0] 381 /// .. 382 /// %x = vector.insert .., .. [.., ..] 383 class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { 384 public: 385 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 386 387 TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions, 388 MLIRContext *context) 389 : OpRewritePattern<vector::TransposeOp>(context), 390 vectorTransformOptions(vectorTransformOptions) {} 391 392 LogicalResult matchAndRewrite(vector::TransposeOp op, 393 PatternRewriter &rewriter) const override { 394 auto loc = op.getLoc(); 395 396 Value input = op.getVector(); 397 VectorType inputType = op.getVectorType(); 398 VectorType resType = op.getResultType(); 399 400 // Set up convenience transposition table. 401 SmallVector<int64_t, 4> transp; 402 for (auto attr : op.getTransp()) 403 transp.push_back(attr.cast<IntegerAttr>().getInt()); 404 405 if (vectorTransformOptions.vectorTransposeLowering == 406 vector::VectorTransposeLowering::Shuffle && 407 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) 408 return rewriter.notifyMatchFailure( 409 op, "Options specifies lowering to shuffle"); 410 411 // Handle a true 2-D matrix transpose differently when requested. 412 if (vectorTransformOptions.vectorTransposeLowering == 413 vector::VectorTransposeLowering::Flat && 414 resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { 415 Type flattenedType = 416 VectorType::get(resType.getNumElements(), resType.getElementType()); 417 auto matrix = 418 rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input); 419 auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); 420 auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); 421 Value trans = rewriter.create<vector::FlatTransposeOp>( 422 loc, flattenedType, matrix, rows, columns); 423 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans); 424 return success(); 425 } 426 427 // Generate unrolled extract/insert ops. We do not unroll the rightmost 428 // (i.e., highest-order) dimensions that are not transposed and leave them 429 // in vector form to improve performance. Therefore, we prune those 430 // dimensions from the shape/transpose data structures used to generate the 431 // extract/insert ops. 432 SmallVector<int64_t, 4> prunedTransp; 433 pruneNonTransposedDims(transp, prunedTransp); 434 size_t numPrunedDims = transp.size() - prunedTransp.size(); 435 auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); 436 SmallVector<int64_t, 4> ones(prunedInShape.size(), 1); 437 auto prunedInStrides = computeStrides(prunedInShape, ones); 438 439 // Generates the extract/insert operations for every scalar/vector element 440 // of the leftmost transposed dimensions. We traverse every transpose 441 // element using a linearized index that we delinearize to generate the 442 // appropriate indices for the extract/insert operations. 
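    // E.g., for a pruned input shape of 2x3, linear index 4 delinearizes to
    // extract position [1, 1], which is then permuted by the pruned transpose
    // pattern to obtain the insert position.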
443 Value result = rewriter.create<arith::ConstantOp>( 444 loc, resType, rewriter.getZeroAttr(resType)); 445 int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); 446 447 for (int64_t linearIdx = 0; linearIdx < numTransposedElements; 448 ++linearIdx) { 449 auto extractIdxs = delinearize(prunedInStrides, linearIdx); 450 SmallVector<int64_t, 4> insertIdxs(extractIdxs); 451 applyPermutationToVector(insertIdxs, prunedTransp); 452 Value extractOp = 453 rewriter.create<vector::ExtractOp>(loc, input, extractIdxs); 454 result = 455 rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs); 456 } 457 458 rewriter.replaceOp(op, result); 459 return success(); 460 } 461 462 private: 463 /// Options to control the vector patterns. 464 vector::VectorTransformsOptions vectorTransformOptions; 465 }; 466 467 /// Rewrite a 2-D vector.transpose as a sequence of: 468 /// vector.shape_cast 2D -> 1D 469 /// vector.shuffle 470 /// vector.shape_cast 1D -> 2D 471 class TransposeOp2DToShuffleLowering 472 : public OpRewritePattern<vector::TransposeOp> { 473 public: 474 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 475 476 TransposeOp2DToShuffleLowering( 477 vector::VectorTransformsOptions vectorTransformOptions, 478 MLIRContext *context) 479 : OpRewritePattern<vector::TransposeOp>(context), 480 vectorTransformOptions(vectorTransformOptions) {} 481 482 LogicalResult matchAndRewrite(vector::TransposeOp op, 483 PatternRewriter &rewriter) const override { 484 auto loc = op.getLoc(); 485 486 VectorType srcType = op.getVectorType(); 487 if (srcType.getRank() != 2) 488 return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); 489 490 SmallVector<int64_t, 4> transp; 491 for (auto attr : op.getTransp()) 492 transp.push_back(attr.cast<IntegerAttr>().getInt()); 493 if (transp[0] != 1 && transp[1] != 0) 494 return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); 495 496 if (vectorTransformOptions.vectorTransposeLowering != 497 VectorTransposeLowering::Shuffle) 498 return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle"); 499 500 int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); 501 Value casted = rewriter.create<vector::ShapeCastOp>( 502 loc, VectorType::get({m * n}, srcType.getElementType()), 503 op.getVector()); 504 SmallVector<int64_t> mask; 505 mask.reserve(m * n); 506 for (int64_t j = 0; j < n; ++j) 507 for (int64_t i = 0; i < m; ++i) 508 mask.push_back(i * n + j); 509 510 Value shuffled = 511 rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask); 512 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(), 513 shuffled); 514 515 return success(); 516 } 517 518 private: 519 /// Options to control the vector patterns. 520 vector::VectorTransformsOptions vectorTransformOptions; 521 }; 522 523 /// Progressive lowering of OuterProductOp. 524 /// One: 525 /// %x = vector.outerproduct %lhs, %rhs, %acc 526 /// is replaced by: 527 /// %z = zero-result 528 /// %0 = vector.extract %lhs[0] 529 /// %1 = vector.broadcast %0 530 /// %2 = vector.extract %acc[0] 531 /// %3 = vector.fma %1, %rhs, %2 532 /// %4 = vector.insert %3, %z[0] 533 /// .. 
534 /// %x = vector.insert %.., %..[N-1] 535 /// 536 class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { 537 public: 538 using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; 539 540 LogicalResult matchAndRewrite(vector::OuterProductOp op, 541 PatternRewriter &rewriter) const override { 542 auto loc = op.getLoc(); 543 544 VectorType lhsType = op.getOperandVectorTypeLHS(); 545 VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>(); 546 VectorType resType = op.getVectorType(); 547 Type eltType = resType.getElementType(); 548 bool isInt = eltType.isa<IntegerType, IndexType>(); 549 Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; 550 vector::CombiningKind kind = op.getKind(); 551 552 if (!rhsType) { 553 // Special case: AXPY operation. 554 Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs()); 555 Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc, 556 kind, rewriter, isInt); 557 if (!mult.hasValue()) 558 return failure(); 559 rewriter.replaceOp(op, mult.getValue()); 560 return success(); 561 } 562 563 Value result = rewriter.create<arith::ConstantOp>( 564 loc, resType, rewriter.getZeroAttr(resType)); 565 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { 566 auto pos = rewriter.getI64ArrayAttr(d); 567 Value x = 568 rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos); 569 Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x); 570 Value r = nullptr; 571 if (acc) 572 r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos); 573 Optional<Value> m = 574 createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt); 575 if (!m.hasValue()) 576 return failure(); 577 result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(), 578 result, pos); 579 } 580 rewriter.replaceOp(op, result); 581 return success(); 582 } 583 }; 584 585 /// Lower vector.contract with all size one reduction dimensions to 586 /// elementwise ops when possible. 
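/// For instance, a contraction whose reduction dimensions all have size 1,
/// such as one combining vector<4x1xf32> and vector<1x8xf32> operands, can be
/// rewritten as broadcasts/transposes that align the parallel dimensions,
/// followed by an elementwise multiply that is combined with the accumulator.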
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: implement masks
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();
    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMaps()[0];
    AffineMap rhsMap = contractOp.getIndexingMaps()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must have size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMaps()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      llvm::Optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist, we will have to broadcast it.
        lhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      llvm::Optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist, we will have to broadcast it.
664 rhsDims.push_back( 665 contractOp.getResultType().cast<VectorType>().getDimSize(i)); 666 rhsTranspose.push_back(rhsDims.size() - 1); 667 } 668 } 669 Value newLhs = contractOp.getLhs(); 670 Value newRhs = contractOp.getRhs(); 671 Location loc = contractOp.getLoc(); 672 if (!lhsDims.empty()) { 673 lhsDims.append(lhsShape.begin(), lhsShape.end()); 674 auto expandedType = 675 VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); 676 newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs); 677 } 678 if (!rhsDims.empty()) { 679 rhsDims.append(rhsShape.begin(), rhsShape.end()); 680 auto expandedType = 681 VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); 682 newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs); 683 } 684 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); 685 newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose); 686 newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose); 687 SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0); 688 SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0); 689 newLhs = rewriter.create<vector::ExtractOp>( 690 loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); 691 newRhs = rewriter.create<vector::ExtractOp>( 692 loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); 693 Optional<Value> result = 694 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), 695 contractOp.getKind(), rewriter, isInt); 696 rewriter.replaceOp(contractOp, {*result}); 697 return success(); 698 } 699 700 private: 701 /// Options to control the vector patterns. 702 vector::VectorTransformsOptions vectorTransformOptions; 703 FilterConstraintType filter; 704 }; 705 706 /// Progressive lowering of ConstantMaskOp. 707 /// One: 708 /// %x = vector.constant_mask [a,b] 709 /// is replaced by: 710 /// %z = zero-result 711 /// %l = vector.constant_mask [b] 712 /// %4 = vector.insert %l, %z[0] 713 /// .. 714 /// %x = vector.insert %l, %..[a-1] 715 /// until a one-dimensional vector is reached. All these operations 716 /// will be folded at LLVM IR level. 717 class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { 718 public: 719 using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern; 720 721 LogicalResult matchAndRewrite(vector::ConstantMaskOp op, 722 PatternRewriter &rewriter) const override { 723 auto loc = op.getLoc(); 724 auto dstType = op.getType(); 725 auto eltType = dstType.getElementType(); 726 auto dimSizes = op.getMaskDimSizes(); 727 int64_t rank = dstType.getRank(); 728 729 if (rank == 0) { 730 assert(dimSizes.size() == 1 && 731 "Expected exactly one dim size for a 0-D vector"); 732 bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1; 733 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 734 op, dstType, 735 DenseIntElementsAttr::get( 736 VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()), 737 ArrayRef<bool>{value})); 738 return success(); 739 } 740 741 // Scalable constant masks can only be lowered for the "none set" case. 742 if (dstType.cast<VectorType>().isScalable()) { 743 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 744 op, DenseElementsAttr::get(dstType, false)); 745 return success(); 746 } 747 748 int64_t trueDim = std::min(dstType.getDimSize(0), 749 dimSizes[0].cast<IntegerAttr>().getInt()); 750 751 if (rank == 1) { 752 // Express constant 1-D case in explicit vector form: 753 // [T,..,T,F,..,F]. 
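      // E.g. vector.constant_mask [2] : vector<4xi1> becomes an arith.constant
      // dense<[true, true, false, false]> : vector<4xi1>.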
754 SmallVector<bool, 4> values(dstType.getDimSize(0)); 755 for (int64_t d = 0; d < trueDim; d++) 756 values[d] = true; 757 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 758 op, dstType, rewriter.getBoolVectorAttr(values)); 759 return success(); 760 } 761 762 VectorType lowType = 763 VectorType::get(dstType.getShape().drop_front(), eltType); 764 SmallVector<int64_t, 4> newDimSizes; 765 for (int64_t r = 1; r < rank; r++) 766 newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt()); 767 Value trueVal = rewriter.create<vector::ConstantMaskOp>( 768 loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); 769 Value result = rewriter.create<arith::ConstantOp>( 770 loc, dstType, rewriter.getZeroAttr(dstType)); 771 for (int64_t d = 0; d < trueDim; d++) { 772 auto pos = rewriter.getI64ArrayAttr(d); 773 result = 774 rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos); 775 } 776 rewriter.replaceOp(op, result); 777 return success(); 778 } 779 }; 780 781 /// Progressive lowering of CreateMaskOp. 782 /// One: 783 /// %x = vector.create_mask %a, ... : vector<dx...> 784 /// is replaced by: 785 /// %l = vector.create_mask ... : vector<...> ; one lower rank 786 /// %0 = arith.cmpi "slt", %ci, %a | 787 /// %1 = select %0, %l, %zeroes | 788 /// %r = vector.insert %1, %pr [i] | d-times 789 /// %x = .... 790 /// until a one-dimensional vector is reached. 791 class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { 792 public: 793 using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern; 794 795 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 796 PatternRewriter &rewriter) const override { 797 auto dstType = op.getResult().getType().cast<VectorType>(); 798 int64_t rank = dstType.getRank(); 799 if (rank <= 1) 800 return rewriter.notifyMatchFailure( 801 op, "0-D and 1-D vectors are handled separately"); 802 803 auto loc = op.getLoc(); 804 auto eltType = dstType.getElementType(); 805 int64_t dim = dstType.getDimSize(0); 806 Value idx = op.getOperand(0); 807 808 VectorType lowType = 809 VectorType::get(dstType.getShape().drop_front(), eltType); 810 Value trueVal = rewriter.create<vector::CreateMaskOp>( 811 loc, lowType, op.getOperands().drop_front()); 812 Value falseVal = rewriter.create<arith::ConstantOp>( 813 loc, lowType, rewriter.getZeroAttr(lowType)); 814 Value result = rewriter.create<arith::ConstantOp>( 815 loc, dstType, rewriter.getZeroAttr(dstType)); 816 for (int64_t d = 0; d < dim; d++) { 817 Value bnd = 818 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d)); 819 Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, 820 bnd, idx); 821 Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal); 822 auto pos = rewriter.getI64ArrayAttr(d); 823 result = 824 rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos); 825 } 826 rewriter.replaceOp(op, result); 827 return success(); 828 } 829 }; 830 831 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D 832 /// vectors progressively on the way to target llvm.matrix intrinsics. 
833 /// This iterates over the most major dimension of the 2-D vector and performs 834 /// rewrites into: 835 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D 836 class ShapeCastOp2DDownCastRewritePattern 837 : public OpRewritePattern<vector::ShapeCastOp> { 838 public: 839 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; 840 841 LogicalResult matchAndRewrite(vector::ShapeCastOp op, 842 PatternRewriter &rewriter) const override { 843 auto sourceVectorType = op.getSourceVectorType(); 844 auto resultVectorType = op.getResultVectorType(); 845 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) 846 return failure(); 847 848 auto loc = op.getLoc(); 849 Value desc = rewriter.create<arith::ConstantOp>( 850 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 851 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; 852 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { 853 Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i); 854 desc = rewriter.create<vector::InsertStridedSliceOp>( 855 loc, vec, desc, 856 /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); 857 } 858 rewriter.replaceOp(op, desc); 859 return success(); 860 } 861 }; 862 863 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D 864 /// vectors progressively. 865 /// This iterates over the most major dimension of the 2-D vector and performs 866 /// rewrites into: 867 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D 868 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. 869 class ShapeCastOp2DUpCastRewritePattern 870 : public OpRewritePattern<vector::ShapeCastOp> { 871 public: 872 using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; 873 874 LogicalResult matchAndRewrite(vector::ShapeCastOp op, 875 PatternRewriter &rewriter) const override { 876 auto sourceVectorType = op.getSourceVectorType(); 877 auto resultVectorType = op.getResultVectorType(); 878 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) 879 return failure(); 880 881 auto loc = op.getLoc(); 882 Value desc = rewriter.create<arith::ConstantOp>( 883 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 884 unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; 885 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { 886 Value vec = rewriter.create<vector::ExtractStridedSliceOp>( 887 loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, 888 /*sizes=*/mostMinorVectorSize, 889 /*strides=*/1); 890 desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i); 891 } 892 rewriter.replaceOp(op, desc); 893 return success(); 894 } 895 }; 896 897 // We typically should not lower general shape cast operations into data 898 // movement instructions, since the assumption is that these casts are 899 // optimized away during progressive lowering. For completeness, however, 900 // we fall back to a reference implementation that moves all elements 901 // into the right place if we get here. 
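// E.g. a vector.shape_cast from vector<2x3xf32> to vector<3x2xf32> is unrolled
// here into 6 scalar vector.extract / vector.insert pairs that walk both
// shapes in row-major order.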
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make it ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    Value zero = rewriter.create<arith::ConstantOp>(
        reduceOp.getLoc(), reduceOp.getDestType(),
        rewriter.getZeroAttr(reduceOp.getDestType()));
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMaps());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // ContractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
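    // E.g. a leading unit dimension introduced only by a folded broadcast no
    // longer appears in any map and is dropped below, together with its
    // iterator.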
    unsigned numDims = maps[0].getNumDims();
    llvm::SmallBitVector unusedDims(numDims, true);
    for (const auto &m : maps) {
      for (unsigned i = 0; i < numDims; ++i) {
        if (m.isFunctionOfDim(i))
          unusedDims.reset(i);
      }
    }
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDims);
    // Compute the combined iterators.
    SmallVector<Attribute, 4> iterators;
    for (unsigned i = 0; i < numDims; ++i) {
      if (!unusedDims.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction
    // iterators. For example, if the vector.contract had only one reduction
    // iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction iterator.
    if (!llvm::any_of(iterators, isReductionIterator))
      return failure();

    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in the CombineContractBroadcast pattern
/// when casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in the
/// CombineContractTranspose pattern when elementwise ops are between these
/// operations.
Ex: 1231 /// ``` 1232 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 1233 /// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> 1234 /// %r = arith.addf %at, %bt : vector<2x4xf32> 1235 /// ``` 1236 /// Gets converted to: 1237 /// ``` 1238 /// %0 = arith.addf %a, %b : vector<4x2xf32> 1239 /// %r = vector.transpose %0, [1, 0] : vector<2x4xf32> 1240 /// ``` 1241 struct ReorderElementwiseOpsOnTranspose final 1242 : public OpTraitRewritePattern<OpTrait::Elementwise> { 1243 using OpTraitRewritePattern::OpTraitRewritePattern; 1244 LogicalResult matchAndRewrite(Operation *op, 1245 PatternRewriter &rewriter) const override { 1246 if (op->getNumResults() != 1 || op->getNumRegions() != 0) 1247 return failure(); 1248 1249 // Make sure all operands are transpose/constant ops and collect their 1250 // transposition maps. 1251 SmallVector<ArrayAttr, 4> transposeMaps; 1252 transposeMaps.reserve(op->getNumOperands()); 1253 // Record the initial type before transposition. We'll use its shape later. 1254 // Any type will do here as we will check all transpose maps are the same. 1255 VectorType srcType; 1256 for (Value operand : op->getOperands()) { 1257 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 1258 if (transposeOp) { 1259 transposeMaps.push_back(transposeOp.getTransp()); 1260 srcType = transposeOp.getVectorType(); 1261 } else if (!matchPattern(operand, m_Constant())) { 1262 return failure(); 1263 } 1264 } 1265 if (transposeMaps.empty()) 1266 return failure(); 1267 // This is an elementwise op, so all transposed operands should have the 1268 // same type. We need to additionally check that all transposes uses the 1269 // same map. 1270 if (!llvm::is_splat(transposeMaps)) 1271 return rewriter.notifyMatchFailure(op, "different transpose map"); 1272 1273 SmallVector<Value, 4> srcValues; 1274 srcValues.reserve(op->getNumOperands()); 1275 1276 // If there are constant operands, we need to insert inverse transposes for 1277 // them. Calculate the inverse order first. 1278 auto order = extractVector<unsigned>(transposeMaps.front()); 1279 SmallVector<int64_t> invOrder(order.size()); 1280 for (int i = 0, e = order.size(); i < e; ++i) 1281 invOrder[order[i]] = i; 1282 1283 for (Value operand : op->getOperands()) { 1284 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>(); 1285 if (transposeOp) { 1286 srcValues.push_back(transposeOp.getVector()); 1287 } else { 1288 // This is a constant. Create a reverse transpose op for it. 1289 auto vectorType = VectorType::get( 1290 srcType.getShape(), 1291 operand.getType().cast<VectorType>().getElementType()); 1292 srcValues.push_back(rewriter.create<vector::TransposeOp>( 1293 operand.getLoc(), vectorType, operand, 1294 rewriter.getI64ArrayAttr(invOrder))); 1295 } 1296 } 1297 1298 auto vectorType = VectorType::get( 1299 srcType.getShape(), 1300 op->getResultTypes()[0].cast<VectorType>().getElementType()); 1301 Operation *elementwiseOp = 1302 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, 1303 vectorType, op->getAttrs()); 1304 rewriter.replaceOpWithNewOp<vector::TransposeOp>( 1305 op, op->getResultTypes()[0], elementwiseOp->getResult(0), 1306 transposeMaps.front()); 1307 return success(); 1308 } 1309 }; 1310 1311 } // namespace 1312 1313 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using 1314 /// operands `x` and `y`. 
1315 static Value createAdd(Location loc, Value x, Value y, bool isInt, 1316 PatternRewriter &rewriter) { 1317 if (isInt) 1318 return rewriter.create<arith::AddIOp>(loc, x, y); 1319 return rewriter.create<arith::AddFOp>(loc, x, y); 1320 } 1321 1322 /// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using 1323 /// operands `x and `y`. 1324 static Value createMul(Location loc, Value x, Value y, bool isInt, 1325 PatternRewriter &rewriter) { 1326 if (isInt) 1327 return rewriter.create<arith::MulIOp>(loc, x, y); 1328 return rewriter.create<arith::MulFOp>(loc, x, y); 1329 } 1330 1331 namespace mlir { 1332 1333 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1334 /// semantics to: 1335 /// ``` 1336 /// %mta = maybe_transpose 1337 /// %mtb = maybe_transpose 1338 /// %flattened_a = vector.shape_cast %mta 1339 /// %flattened_b = vector.shape_cast %mtb 1340 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 1341 /// %mtd = vector.shape_cast %flattened_d 1342 /// %d = maybe_untranspose %mtd 1343 /// %e = add %c, %d 1344 /// ``` 1345 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 1346 // 1347 /// This only kicks in when VectorTransformsOptions is set to `Matmul`. 1348 /// vector.transpose operations are inserted if the vector.contract op is not a 1349 /// row-major matrix multiply. 1350 LogicalResult 1351 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, 1352 PatternRewriter &rew) const { 1353 // TODO: implement masks 1354 if (llvm::size(op.getMasks()) != 0) 1355 return failure(); 1356 if (vectorTransformOptions.vectorContractLowering != 1357 vector::VectorContractLowering::Matmul) 1358 return failure(); 1359 if (failed(filter(op))) 1360 return failure(); 1361 1362 auto iteratorTypes = op.getIteratorTypes().getValue(); 1363 if (!isParallelIterator(iteratorTypes[0]) || 1364 !isParallelIterator(iteratorTypes[1]) || 1365 !isReductionIterator(iteratorTypes[2])) 1366 return failure(); 1367 1368 Type elementType = op.getLhsType().getElementType(); 1369 if (!elementType.isIntOrFloat()) 1370 return failure(); 1371 1372 Type dstElementType = op.getType(); 1373 if (auto vecType = dstElementType.dyn_cast<VectorType>()) 1374 dstElementType = vecType.getElementType(); 1375 if (elementType != dstElementType) 1376 return failure(); 1377 1378 // Perform lhs + rhs transpositions to conform to matmul row-major semantics. 1379 // Bail out if the contraction cannot be put in this form. 1380 MLIRContext *ctx = op.getContext(); 1381 Location loc = op.getLoc(); 1382 AffineExpr m, n, k; 1383 bindDims(rew.getContext(), m, n, k); 1384 // LHS must be A(m, k) or A(k, m). 1385 Value lhs = op.getLhs(); 1386 auto lhsMap = op.getIndexingMaps()[0]; 1387 if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) 1388 lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0}); 1389 else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) 1390 return failure(); 1391 1392 // RHS must be B(k, n) or B(n, k). 1393 Value rhs = op.getRhs(); 1394 auto rhsMap = op.getIndexingMaps()[1]; 1395 if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) 1396 rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0}); 1397 else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) 1398 return failure(); 1399 1400 // At this point lhs and rhs are in row-major. 
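  // Flatten both operands to 1-D vectors: vector.matmul (and the
  // llvm.matrix.multiply intrinsic it lowers to) operates on row-major
  // flattened matrices.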
1401 VectorType lhsType = lhs.getType().cast<VectorType>(); 1402 VectorType rhsType = rhs.getType().cast<VectorType>(); 1403 int64_t lhsRows = lhsType.getDimSize(0); 1404 int64_t lhsColumns = lhsType.getDimSize(1); 1405 int64_t rhsColumns = rhsType.getDimSize(1); 1406 1407 Type flattenedLHSType = 1408 VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); 1409 lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs); 1410 1411 Type flattenedRHSType = 1412 VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); 1413 rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs); 1414 1415 Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns, 1416 rhsColumns); 1417 mul = rew.create<vector::ShapeCastOp>( 1418 loc, 1419 VectorType::get({lhsRows, rhsColumns}, 1420 getElementTypeOrSelf(op.getAcc().getType())), 1421 mul); 1422 1423 // ACC must be C(m, n) or C(n, m). 1424 auto accMap = op.getIndexingMaps()[2]; 1425 if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) 1426 mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0}); 1427 else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) 1428 llvm_unreachable("invalid contraction semantics"); 1429 1430 Value res = 1431 elementType.isa<IntegerType>() 1432 ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul)) 1433 : static_cast<Value>( 1434 rew.create<arith::AddFOp>(loc, op.getAcc(), mul)); 1435 1436 rew.replaceOp(op, res); 1437 return success(); 1438 } 1439 1440 namespace { 1441 struct IteratorType { 1442 IteratorType(StringRef strRef) : strRef(strRef) {} 1443 bool isOfType(Attribute attr) const { 1444 auto sAttr = attr.dyn_cast<StringAttr>(); 1445 return sAttr && sAttr.getValue() == strRef; 1446 } 1447 StringRef strRef; 1448 }; 1449 struct Par : public IteratorType { 1450 Par() : IteratorType(getParallelIteratorTypeName()) {} 1451 }; 1452 struct Red : public IteratorType { 1453 Red() : IteratorType(getReductionIteratorTypeName()) {} 1454 }; 1455 1456 /// Generate a vector implementation for matmat, matvec and tmatvec. 1457 /// This unrolls outer-products along the reduction dimension. 
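/// E.g., for a matmat contraction with a reduction dimension of size K, this
/// emits K vector.outerproduct ops, each combining one slice of the (possibly
/// transposed) lhs with the matching slice of the rhs and using the previous
/// result as its accumulator.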
1458 struct UnrolledOuterProductGenerator 1459 : public StructuredGenerator<vector::ContractionOp> { 1460 UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) 1461 : StructuredGenerator<vector::ContractionOp>(builder, op), 1462 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), 1463 res(op.getAcc()), lhsType(op.getLhsType()) {} 1464 1465 Value t(Value v) { 1466 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1467 return builder.create<vector::TransposeOp>(loc, v, perm); 1468 } 1469 1470 Value promote(Value v, Type dstElementType) { 1471 Type elementType = v.getType(); 1472 auto vecType = elementType.dyn_cast<VectorType>(); 1473 if (vecType) 1474 elementType = vecType.getElementType(); 1475 if (elementType == dstElementType) 1476 return v; 1477 Type promotedType = dstElementType; 1478 if (vecType) 1479 promotedType = VectorType::get(vecType.getShape(), promotedType); 1480 if (dstElementType.isa<FloatType>()) 1481 return builder.create<arith::ExtFOp>(loc, promotedType, v); 1482 return builder.create<arith::ExtSIOp>(loc, promotedType, v); 1483 } 1484 1485 Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { 1486 assert(reductionSize > 0); 1487 Type resElementType = res.getType().cast<VectorType>().getElementType(); 1488 for (int64_t k = 0; k < reductionSize; ++k) { 1489 Value a = builder.create<vector::ExtractOp>(loc, lhs, k); 1490 Value b = builder.create<vector::ExtractOp>(loc, rhs, k); 1491 a = promote(a, resElementType); 1492 b = promote(b, resElementType); 1493 res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b, 1494 res, kind); 1495 } 1496 return res; 1497 } 1498 1499 /// Two outer parallel, one inner reduction (matmat flavor). 1500 FailureOr<Value> matmat() { 1501 if (!iters({Par(), Par(), Red()})) 1502 return failure(); 1503 // Set up the parallel/reduction structure in the right form. 1504 AffineExpr m, n, k; 1505 bindDims(builder.getContext(), m, n, k); 1506 // Classical row-major matmul: Just permute the lhs. 1507 if (layout({{m, k}, {k, n}, {m, n}})) 1508 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1509 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 1510 if (layout({{m, k}, {n, k}, {m, n}})) { 1511 Value tlhs = t(lhs); 1512 return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1)); 1513 } 1514 // No need to permute anything. 1515 if (layout({{k, m}, {k, n}, {m, n}})) 1516 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1517 // Just permute the rhs. 1518 if (layout({{k, m}, {n, k}, {m, n}})) 1519 return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0)); 1520 // Transposed output: swap RHS and LHS. 1521 // Classical row-major matmul: permute the lhs. 1522 if (layout({{m, k}, {k, n}, {n, m}})) 1523 return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1)); 1524 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 
1525 if (layout({{m, k}, {n, k}, {n, m}})) { 1526 Value trhs = t(rhs); 1527 return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1)); 1528 } 1529 if (layout({{k, m}, {k, n}, {n, m}})) 1530 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1531 if (layout({{k, m}, {n, k}, {n, m}})) 1532 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1533 return failure(); 1534 } 1535 1536 /// One outer parallel, one inner reduction (matvec flavor) 1537 FailureOr<Value> matvec() { 1538 if (!iters({Par(), Red()})) 1539 return failure(); 1540 AffineExpr m, k; 1541 bindDims(builder.getContext(), m, k); 1542 1543 // Case mat-vec: transpose. 1544 if (layout({{m, k}, {k}, {m}})) 1545 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1546 // Case mat-trans-vec: ready to go. 1547 if (layout({{k, m}, {k}, {m}})) 1548 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1549 // Case vec-mat: swap and transpose. 1550 if (layout({{k}, {m, k}, {m}})) 1551 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1552 // Case vec-mat-trans: swap and ready to go. 1553 if (layout({{k}, {k, m}, {m}})) 1554 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1555 return failure(); 1556 } 1557 1558 // 1559 // One outer reduction, one inner parallel (tmatvec flavor) 1560 // 1561 FailureOr<Value> tmatvec() { 1562 if (!iters({Red(), Par()})) 1563 return failure(); 1564 AffineExpr k, m; 1565 bindDims(builder.getContext(), k, m); 1566 1567 // Case mat-vec: transpose. 1568 if (layout({{m, k}, {k}, {m}})) 1569 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1570 // Case mat-trans-vec: ready to go. 1571 if (layout({{k, m}, {k}, {m}})) 1572 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1573 // Case vec-mat: swap and transpose. 1574 if (layout({{k}, {m, k}, {m}})) 1575 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1576 // Case vec-mat-trans: swap and ready to go. 1577 if (layout({{k}, {k, m}, {m}})) 1578 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1579 return failure(); 1580 } 1581 1582 private: 1583 vector::CombiningKind kind; 1584 Value lhs, rhs, res; 1585 VectorType lhsType; 1586 }; 1587 } // namespace 1588 1589 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1590 /// semantics to a reduction_size-unrolled sequence: 1591 /// ``` 1592 /// %at = vector.transpose %a, [1, 0] 1593 /// %bRow0 = vector.extract %b[0] 1594 /// %atRow0 = vector.extract %at[0] 1595 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 1596 /// ... 1597 /// %bRowK = vector.extract %b[K] 1598 /// %atRowK = vector.extract %at[K] 1599 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 1600 /// ``` 1601 /// 1602 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but 1603 /// otherwise supports any layout permutation of the matrix-multiply. 
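///
/// For concreteness (hypothetical shapes): with %a : vector<2x4xf32>,
/// %b : vector<4x3xf32> and %c : vector<2x3xf32>, the reduction size is 4 and
/// the pattern emits four chained vector.outerproduct ops, each combining a
/// vector<2xf32> column of %a with a vector<3xf32> row of %b into the
/// vector<2x3xf32> accumulator.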
1604 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( 1605 vector::ContractionOp op, PatternRewriter &rewriter) const { 1606 // TODO: implement masks 1607 if (llvm::size(op.getMasks()) != 0) 1608 return failure(); 1609 1610 if (vectorTransformOptions.vectorContractLowering != 1611 vector::VectorContractLowering::OuterProduct) 1612 return failure(); 1613 1614 if (failed(filter(op))) 1615 return failure(); 1616 1617 UnrolledOuterProductGenerator e(rewriter, op); 1618 FailureOr<Value> matmatRes = e.matmat(); 1619 if (succeeded(matmatRes)) { 1620 rewriter.replaceOp(op, *matmatRes); 1621 return success(); 1622 } 1623 FailureOr<Value> matvecRes = e.matvec(); 1624 if (succeeded(matvecRes)) { 1625 rewriter.replaceOp(op, *matvecRes); 1626 return success(); 1627 } 1628 FailureOr<Value> tmatvecRes = e.tmatvec(); 1629 if (succeeded(tmatvecRes)) { 1630 rewriter.replaceOp(op, *tmatvecRes); 1631 return success(); 1632 } 1633 1634 return failure(); 1635 } 1636 1637 LogicalResult 1638 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, 1639 PatternRewriter &rewriter) const { 1640 // TODO: implement masks 1641 if (llvm::size(op.getMasks()) != 0) 1642 return failure(); 1643 1644 if (failed(filter(op))) 1645 return failure(); 1646 1647 if (vectorTransformOptions.vectorContractLowering != 1648 vector::VectorContractLowering::Dot) 1649 return failure(); 1650 1651 auto iteratorTypes = op.getIteratorTypes().getValue(); 1652 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1653 Location loc = op.getLoc(); 1654 Value lhs = op.getLhs(), rhs = op.getRhs(); 1655 1656 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 1657 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; 1658 AffineExpr m, n, k; 1659 bindDims(rewriter.getContext(), m, n, k); 1660 SmallVector<AffineMap, 4> maps = op.getIndexingMaps(); 1661 // 1662 // In the following we wish to make the reduction dimension innermost so we 1663 // can load vectors and just fmul + reduce into a scalar. 1664 // 1665 if (isParallelIterator(iteratorTypes[0]) && 1666 isParallelIterator(iteratorTypes[1]) && 1667 isReductionIterator(iteratorTypes[2])) { 1668 // 1669 // Two outer parallel, one inner reduction (matmat flavor). 1670 // 1671 if (maps == infer({{m, k}, {k, n}, {m, n}})) { 1672 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1673 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { 1674 // No need to permute anything. 1675 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 1676 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1677 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1678 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 1679 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1680 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 1681 // This is the classical row-major matmul. Just permute the lhs. 
1682 Value tmp = lhs; 1683 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1684 rhs = tmp; 1685 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 1686 std::swap(lhs, rhs); 1687 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 1688 Value tmp = lhs; 1689 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1690 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm); 1691 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 1692 Value tmp = rhs; 1693 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1694 lhs = tmp; 1695 } else { 1696 return failure(); 1697 } 1698 } else if (isParallelIterator(iteratorTypes[0]) && 1699 isReductionIterator(iteratorTypes[1])) { 1700 // 1701 // One outer parallel, one inner reduction (matvec flavor) 1702 // 1703 if (maps == infer({{m, n}, {n}, {m}})) { 1704 // No need to permute anything. 1705 } else if (maps == infer({{n, m}, {n}, {m}})) { 1706 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1707 } else if (maps == infer({{n}, {m, n}, {m}})) { 1708 std::swap(lhs, rhs); 1709 } else if (maps == infer({{n}, {n, m}, {m}})) { 1710 std::swap(lhs, rhs); 1711 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1712 } else { 1713 return failure(); 1714 } 1715 } else { 1716 return failure(); 1717 } 1718 1719 VectorType dstType = op.getResultType().cast<VectorType>(); 1720 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && 1721 "Expected dst type of rank 1 or 2"); 1722 1723 unsigned rank = dstType.getRank(); 1724 unsigned dstRows = dstType.getShape()[0]; 1725 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; 1726 1727 // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 1728 Value res = rewriter.create<arith::ConstantOp>(loc, dstType, 1729 rewriter.getZeroAttr(dstType)); 1730 bool isInt = dstType.getElementType().isa<IntegerType>(); 1731 for (unsigned r = 0; r < dstRows; ++r) { 1732 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r); 1733 for (unsigned c = 0; c < dstColumns; ++c) { 1734 Value b = rank == 1 1735 ? rhs 1736 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c); 1737 Value m = createMul(op.getLoc(), a, b, isInt, rewriter); 1738 Value reduced = rewriter.create<vector::ReductionOp>( 1739 op.getLoc(), vector::CombiningKind::ADD, m); 1740 1741 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r} 1742 : SmallVector<int64_t, 2>{r, c}; 1743 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos); 1744 } 1745 } 1746 if (auto acc = op.getAcc()) 1747 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1748 rewriter.replaceOp(op, res); 1749 return success(); 1750 } 1751 1752 /// Progressive lowering of ContractionOp. 1753 /// One: 1754 /// %x = vector.contract with at least one free/batch dimension 1755 /// is replaced by: 1756 /// %a = vector.contract with one less free/batch dimension 1757 /// %b = vector.contract with one less free/batch dimension 1758 /// .. 1759 /// %x = combine %a %b .. 1760 /// until a pure contraction is reached (no free/batch dimensions), 1761 /// which is replaced by a dot-product. 1762 /// 1763 /// This only kicks in when either VectorTransformsOptions is set 1764 /// to DOT or when other contraction patterns fail. 
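///
/// As an illustrative sketch (hypothetical shapes and maps), a contraction
/// with one batch dimension of size 2:
/// ```
///   %x = vector.contract {
///          indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
///                           affine_map<(b, m, n, k) -> (b, k, n)>,
///                           affine_map<(b, m, n, k) -> (b, m, n)>],
///          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
///        %lhs, %rhs, %acc
///     : vector<2x3x5xf32>, vector<2x5x4xf32> into vector<2x3x4xf32>
/// ```
/// is rewritten into two contractions over vector<3x5xf32> / vector<5x4xf32>
/// operands whose vector<3x4xf32> results are inserted back into the
/// vector<2x3x4xf32> result.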
1765 // 1766 // TODO: break down into transpose/reshape/cast ops 1767 // when they become available to avoid code dup 1768 // TODO: investigate lowering order impact on performance 1769 LogicalResult 1770 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1771 PatternRewriter &rewriter) const { 1772 // TODO: implement masks. 1773 if (llvm::size(op.getMasks()) != 0) 1774 return failure(); 1775 1776 if (failed(filter(op))) 1777 return failure(); 1778 1779 // TODO: support mixed mode contract lowering. 1780 if (op.getLhsType().getElementType() != 1781 getElementTypeOrSelf(op.getAccType()) || 1782 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1783 return failure(); 1784 1785 // TODO: implement benefits, cost models. 1786 MLIRContext *ctx = op.getContext(); 1787 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1788 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1789 return success(); 1790 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1791 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1792 return success(); 1793 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1794 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1795 return success(); 1796 ContractOpToElementwise pat4(vectorTransformOptions, ctx); 1797 if (succeeded(pat4.matchAndRewrite(op, rewriter))) 1798 return success(); 1799 1800 // Find first batch dimension in LHS/RHS, and lower when found. 1801 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1802 if (!batchDimMap.empty()) { 1803 int64_t lhsIndex = batchDimMap[0].first; 1804 int64_t rhsIndex = batchDimMap[0].second; 1805 rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); 1806 return success(); 1807 } 1808 1809 // Collect contracting dimensions. 1810 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1811 op.getContractingDimMap(); 1812 DenseSet<int64_t> lhsContractingDimSet; 1813 DenseSet<int64_t> rhsContractingDimSet; 1814 for (auto &dimPair : contractingDimMap) { 1815 lhsContractingDimSet.insert(dimPair.first); 1816 rhsContractingDimSet.insert(dimPair.second); 1817 } 1818 1819 // Find first free dimension in LHS, and lower when found. 1820 VectorType lhsType = op.getLhsType(); 1821 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1822 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1823 rewriter.replaceOp( 1824 op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); 1825 return success(); 1826 } 1827 } 1828 1829 // Find first free dimension in RHS, and lower when found. 1830 VectorType rhsType = op.getRhsType(); 1831 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1832 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1833 rewriter.replaceOp( 1834 op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); 1835 return success(); 1836 } 1837 } 1838 1839 // Lower the first remaining reduction dimension. 1840 if (!contractingDimMap.empty()) { 1841 rewriter.replaceOp(op, lowerReduction(op, rewriter)); 1842 return success(); 1843 } 1844 1845 return failure(); 1846 } 1847 1848 // Lower one parallel dimension. 
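// For instance (hypothetical maps): unrolling the batch dimension `b` of a
// contraction with indexing maps
//   (b, m, k), (b, k, n) -> (b, m, n)
// produces lower-rank vector.contract ops whose maps drop `b`:
//   (m, k), (k, n) -> (m, n)
// with the corresponding iterator type removed (see adjustMap / adjustIter).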
1849 // TODO: consider reusing existing contract unrolling 1850 Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, 1851 int64_t lhsIndex, int64_t rhsIndex, 1852 PatternRewriter &rewriter) const { 1853 VectorType lhsType = op.getLhsType(); 1854 VectorType rhsType = op.getRhsType(); 1855 VectorType resType = op.getResultType().cast<VectorType>(); 1856 // Find the iterator type index and result index. 1857 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1858 int64_t iterIndex = -1; 1859 int64_t dimSize = -1; 1860 if (lhsIndex >= 0) { 1861 iterIndex = iMap[0].getDimPosition(lhsIndex); 1862 assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) && 1863 "parallel index should be free in LHS or batch in LHS/RHS"); 1864 dimSize = lhsType.getDimSize(lhsIndex); 1865 } else { 1866 assert(rhsIndex >= 0 && "missing parallel index"); 1867 iterIndex = iMap[1].getDimPosition(rhsIndex); 1868 dimSize = rhsType.getDimSize(rhsIndex); 1869 } 1870 assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); 1871 Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex); 1872 assert(lookup.hasValue() && "parallel index not listed in reduction"); 1873 int64_t resIndex = lookup.getValue(); 1874 // Construct new iterator types and affine map array attribute. 1875 std::array<AffineMap, 3> lowIndexingMaps = { 1876 adjustMap(iMap[0], iterIndex, rewriter), 1877 adjustMap(iMap[1], iterIndex, rewriter), 1878 adjustMap(iMap[2], iterIndex, rewriter)}; 1879 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1880 auto lowIter = 1881 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1882 // Unroll into a series of lower dimensional vector.contract ops. 1883 Location loc = op.getLoc(); 1884 Value result = rewriter.create<arith::ConstantOp>( 1885 loc, resType, rewriter.getZeroAttr(resType)); 1886 for (int64_t d = 0; d < dimSize; ++d) { 1887 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1888 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1889 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1890 Value lowContract = rewriter.create<vector::ContractionOp>( 1891 loc, lhs, rhs, acc, lowAffine, lowIter); 1892 result = 1893 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1894 } 1895 return result; 1896 } 1897 1898 // Lower one reduction dimension. 1899 Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, 1900 PatternRewriter &rewriter) const { 1901 auto loc = op.getLoc(); 1902 VectorType lhsType = op.getLhsType(); 1903 VectorType rhsType = op.getRhsType(); 1904 Type resType = op.getResultType(); 1905 assert(!resType.isa<VectorType>()); 1906 bool isInt = resType.isa<IntegerType>(); 1907 // Use iterator index 0. 1908 int64_t iterIndex = 0; 1909 SmallVector<AffineMap, 4> iMap = op.getIndexingMaps(); 1910 Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex); 1911 Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex); 1912 assert(lookupLhs.hasValue() && "missing LHS parallel index"); 1913 assert(lookupRhs.hasValue() && "missing RHS parallel index"); 1914 int64_t lhsIndex = lookupLhs.getValue(); 1915 int64_t rhsIndex = lookupRhs.getValue(); 1916 int64_t dimSize = lhsType.getDimSize(lhsIndex); 1917 assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); 1918 // Base case. 
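  // (Illustrative sketch, hypothetical types and approximate syntax: a rank-1
  // contraction becomes an elementwise multiply feeding a reduction, e.g.
  //   %m = arith.mulf %lhs, %rhs : vector<4xf32>
  //   %r = vector.reduction <add>, %m, %acc : vector<4xf32> into f32
  // when an accumulator is present.)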
1919 if (lhsType.getRank() == 1) { 1920 assert(rhsType.getRank() == 1 && "corrupt contraction"); 1921 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); 1922 auto kind = vector::CombiningKind::ADD; 1923 if (auto acc = op.getAcc()) 1924 return rewriter.create<vector::ReductionOp>(loc, kind, m, acc); 1925 return rewriter.create<vector::ReductionOp>(loc, kind, m); 1926 } 1927 // Construct new iterator types and affine map array attribute. 1928 std::array<AffineMap, 3> lowIndexingMaps = { 1929 adjustMap(iMap[0], iterIndex, rewriter), 1930 adjustMap(iMap[1], iterIndex, rewriter), 1931 adjustMap(iMap[2], iterIndex, rewriter)}; 1932 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1933 auto lowIter = 1934 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1935 // Unroll into a series of lower dimensional vector.contract ops. 1936 // By feeding the initial accumulator into the first contraction, 1937 // and the result of each contraction into the next, eventually 1938 // the sum of all reductions is computed. 1939 Value result = op.getAcc(); 1940 for (int64_t d = 0; d < dimSize; ++d) { 1941 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1942 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1943 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1944 lowAffine, lowIter); 1945 } 1946 return result; 1947 } 1948 1949 } // namespace mlir 1950 1951 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1952 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1953 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 1954 OpBuilder::InsertionGuard guard(builder); 1955 builder.setInsertionPointAfter(op); 1956 Location loc = op->getLoc(); 1957 if (op->getNumResults() != 1) 1958 return {}; 1959 Value result = op->getResult(0); 1960 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 1961 if (!type || map.getNumResults() != multiplicity.size()) 1962 return {}; 1963 // For each dimension being distributed check that the size is a multiple of 1964 // the multiplicity. To handle more sizes we would need to support masking. 1965 unsigned multiplictyCount = 0; 1966 for (auto exp : map.getResults()) { 1967 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 1968 if (!affinExp || affinExp.getPosition() >= type.getRank() || 1969 type.getDimSize(affinExp.getPosition()) % 1970 multiplicity[multiplictyCount++] != 1971 0) 1972 return {}; 1973 } 1974 DistributeOps ops; 1975 ops.extract = 1976 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 1977 ops.insert = 1978 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 1979 return ops; 1980 } 1981 1982 /// Progressive lowering of transfer_read. This pattern supports lowering of 1983 /// `vector.transfer_read` to a combination of `vector.load` and 1984 /// `vector.broadcast` if all of the following hold: 1985 /// - Stride of most minor memref dimension must be 1. 1986 /// - Out-of-bounds masking is not required. 1987 /// - If the memref's element type is a vector type then it coincides with the 1988 /// result type. 1989 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 
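///
/// For example (illustrative sketch, hypothetical types), an unmasked,
/// in-bounds, minor-identity read
/// ```
///   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
///       : memref<8x16xf32>, vector<4xf32>
/// ```
/// lowers to
/// ```
///   %v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<4xf32>
/// ```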
1990 struct TransferReadToVectorLoadLowering 1991 : public OpRewritePattern<vector::TransferReadOp> { 1992 TransferReadToVectorLoadLowering(MLIRContext *context, 1993 llvm::Optional<unsigned> maxRank) 1994 : OpRewritePattern<vector::TransferReadOp>(context), 1995 maxTransferRank(maxRank) {} 1996 1997 LogicalResult matchAndRewrite(vector::TransferReadOp read, 1998 PatternRewriter &rewriter) const override { 1999 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 2000 return failure(); 2001 2002 SmallVector<unsigned, 4> broadcastedDims; 2003 // Permutations are handled by VectorToSCF or 2004 // populateVectorTransferPermutationMapLoweringPatterns. 2005 // We let the 0-d corner case pass-through as it is supported. 2006 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( 2007 &broadcastedDims)) 2008 return failure(); 2009 2010 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 2011 if (!memRefType) 2012 return failure(); 2013 2014 // Non-unit strides are handled by VectorToSCF. 2015 if (!vector::isLastMemrefDimUnitStride(memRefType)) 2016 return failure(); 2017 2018 // If there is broadcasting involved then we first load the unbroadcasted 2019 // vector, and then broadcast it with `vector.broadcast`. 2020 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 2021 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 2022 vectorShape.end()); 2023 for (unsigned i : broadcastedDims) 2024 unbroadcastedVectorShape[i] = 1; 2025 VectorType unbroadcastedVectorType = VectorType::get( 2026 unbroadcastedVectorShape, read.getVectorType().getElementType()); 2027 2028 // `vector.load` supports vector types as memref's elements only when the 2029 // resulting vector type is the same as the element type. 2030 auto memrefElTy = memRefType.getElementType(); 2031 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 2032 return failure(); 2033 2034 // Otherwise, element types of the memref and the vector must match. 2035 if (!memrefElTy.isa<VectorType>() && 2036 memrefElTy != read.getVectorType().getElementType()) 2037 return failure(); 2038 2039 // Out-of-bounds dims are handled by MaterializeTransferMask. 2040 if (read.hasOutOfBoundsDim()) 2041 return failure(); 2042 2043 // Create vector load op. 2044 Operation *loadOp; 2045 if (read.getMask()) { 2046 Value fill = rewriter.create<vector::SplatOp>( 2047 read.getLoc(), unbroadcastedVectorType, read.getPadding()); 2048 loadOp = rewriter.create<vector::MaskedLoadOp>( 2049 read.getLoc(), unbroadcastedVectorType, read.getSource(), 2050 read.getIndices(), read.getMask(), fill); 2051 } else { 2052 loadOp = rewriter.create<vector::LoadOp>( 2053 read.getLoc(), unbroadcastedVectorType, read.getSource(), 2054 read.getIndices()); 2055 } 2056 2057 // Insert a broadcasting op if required. 2058 if (!broadcastedDims.empty()) { 2059 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 2060 read, read.getVectorType(), loadOp->getResult(0)); 2061 } else { 2062 rewriter.replaceOp(read, loadOp->getResult(0)); 2063 } 2064 2065 return success(); 2066 } 2067 2068 llvm::Optional<unsigned> maxTransferRank; 2069 }; 2070 2071 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 2072 // TODO: we shouldn't cross the vector/scalar domains just for this 2073 // but atm we lack the infra to avoid it. 
Possible solutions include: 2074 // - go directly to LLVM + bitcast 2075 // - introduce a bitcast op and likely a new pointer dialect 2076 // - let memref.load/store additionally support the 0-d vector case 2077 // There are still deeper data layout issues lingering even in this 2078 // trivial case (for architectures for which this matters). 2079 struct VectorLoadToMemrefLoadLowering 2080 : public OpRewritePattern<vector::LoadOp> { 2081 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 2082 2083 LogicalResult matchAndRewrite(vector::LoadOp loadOp, 2084 PatternRewriter &rewriter) const override { 2085 auto vecType = loadOp.getVectorType(); 2086 if (vecType.getNumElements() != 1) 2087 return failure(); 2088 auto memrefLoad = rewriter.create<memref::LoadOp>( 2089 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); 2090 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, 2091 memrefLoad); 2092 return success(); 2093 } 2094 }; 2095 2096 /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 2097 struct VectorStoreToMemrefStoreLowering 2098 : public OpRewritePattern<vector::StoreOp> { 2099 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 2100 2101 LogicalResult matchAndRewrite(vector::StoreOp storeOp, 2102 PatternRewriter &rewriter) const override { 2103 auto vecType = storeOp.getVectorType(); 2104 if (vecType.getNumElements() != 1) 2105 return failure(); 2106 Value extracted; 2107 if (vecType.getRank() == 0) { 2108 // TODO: Unifiy once ExtractOp supports 0-d vectors. 2109 extracted = rewriter.create<vector::ExtractElementOp>( 2110 storeOp.getLoc(), storeOp.getValueToStore()); 2111 } else { 2112 SmallVector<int64_t> indices(vecType.getRank(), 0); 2113 extracted = rewriter.create<vector::ExtractOp>( 2114 storeOp.getLoc(), storeOp.getValueToStore(), indices); 2115 } 2116 2117 rewriter.replaceOpWithNewOp<memref::StoreOp>( 2118 storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); 2119 return success(); 2120 } 2121 }; 2122 2123 /// Progressive lowering of transfer_write. This pattern supports lowering of 2124 /// `vector.transfer_write` to `vector.store` if all of the following hold: 2125 /// - Stride of most minor memref dimension must be 1. 2126 /// - Out-of-bounds masking is not required. 2127 /// - If the memref's element type is a vector type then it coincides with the 2128 /// type of the written value. 2129 /// - The permutation map is the minor identity map (neither permutation nor 2130 /// broadcasting is allowed). 2131 struct TransferWriteToVectorStoreLowering 2132 : public OpRewritePattern<vector::TransferWriteOp> { 2133 TransferWriteToVectorStoreLowering(MLIRContext *context, 2134 llvm::Optional<unsigned> maxRank) 2135 : OpRewritePattern<vector::TransferWriteOp>(context), 2136 maxTransferRank(maxRank) {} 2137 2138 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 2139 PatternRewriter &rewriter) const override { 2140 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) 2141 return failure(); 2142 2143 // Permutations are handled by VectorToSCF or 2144 // populateVectorTransferPermutationMapLoweringPatterns. 2145 if ( // pass-through for the 0-d corner case. 2146 !write.getPermutationMap().isMinorIdentity()) 2147 return failure(); 2148 2149 auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); 2150 if (!memRefType) 2151 return failure(); 2152 2153 // Non-unit strides are handled by VectorToSCF. 
2154 if (!vector::isLastMemrefDimUnitStride(memRefType)) 2155 return failure(); 2156 2157 // `vector.store` supports vector types as memref's elements only when the 2158 // type of the vector value being written is the same as the element type. 2159 auto memrefElTy = memRefType.getElementType(); 2160 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) 2161 return failure(); 2162 2163 // Otherwise, element types of the memref and the vector must match. 2164 if (!memrefElTy.isa<VectorType>() && 2165 memrefElTy != write.getVectorType().getElementType()) 2166 return failure(); 2167 2168 // Out-of-bounds dims are handled by MaterializeTransferMask. 2169 if (write.hasOutOfBoundsDim()) 2170 return failure(); 2171 if (write.getMask()) { 2172 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>( 2173 write, write.getSource(), write.getIndices(), write.getMask(), 2174 write.getVector()); 2175 } else { 2176 rewriter.replaceOpWithNewOp<vector::StoreOp>( 2177 write, write.getVector(), write.getSource(), write.getIndices()); 2178 } 2179 return success(); 2180 } 2181 2182 llvm::Optional<unsigned> maxTransferRank; 2183 }; 2184 2185 // Returns the values in `arrayAttr` as an integer vector. 2186 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) { 2187 return llvm::to_vector<4>( 2188 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(), 2189 [](IntegerAttr attr) { return attr.getInt(); })); 2190 } 2191 2192 // Shuffles vector.bitcast op after vector.extract op. 2193 // 2194 // This transforms IR like: 2195 // %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> 2196 // %1 = vector.extract %0[3] : vector<8xf16> 2197 // Into: 2198 // %0 = vector.extract %src[1] : vector<4xf32> 2199 // %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16> 2200 // %2 = vector.extract %1[1] : vector<2xf16> 2201 struct BubbleDownVectorBitCastForExtract 2202 : public OpRewritePattern<vector::ExtractOp> { 2203 using OpRewritePattern::OpRewritePattern; 2204 2205 LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 2206 PatternRewriter &rewriter) const override { 2207 // Only support extracting scalars for now. 2208 if (extractOp.getVectorType().getRank() != 1) 2209 return failure(); 2210 2211 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); 2212 if (!castOp) 2213 return failure(); 2214 2215 VectorType castSrcType = castOp.getSourceVectorType(); 2216 VectorType castDstType = castOp.getResultVectorType(); 2217 assert(castSrcType.getRank() == castDstType.getRank()); 2218 2219 // Fail to match if we only have one element in the cast op source. 2220 // This is to avoid infinite loop given that this pattern can generate 2221 // such cases. 2222 if (castSrcType.getNumElements() == 1) 2223 return failure(); 2224 2225 // Only support casting to a larger number of elements or now. 2226 // E.g., vector<4xf32> -> vector<8xf16>. 2227 if (castSrcType.getNumElements() > castDstType.getNumElements()) 2228 return failure(); 2229 2230 unsigned expandRatio = 2231 castDstType.getNumElements() / castSrcType.getNumElements(); 2232 2233 auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t { 2234 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 2235 }; 2236 2237 uint64_t index = getFirstIntValue(extractOp.getPosition()); 2238 2239 // Get the single scalar (as a vector) in the source value that packs the 2240 // desired scalar. E.g. 
extract vector<1xf32> from vector<4xf32> 2241 VectorType oneScalarType = 2242 VectorType::get({1}, castSrcType.getElementType()); 2243 Value packedValue = rewriter.create<vector::ExtractOp>( 2244 extractOp.getLoc(), oneScalarType, castOp.getSource(), 2245 rewriter.getI64ArrayAttr(index / expandRatio)); 2246 2247 // Cast it to a vector with the desired scalar's type. 2248 // E.g. f32 -> vector<2xf16> 2249 VectorType packedType = 2250 VectorType::get({expandRatio}, castDstType.getElementType()); 2251 Value castedValue = rewriter.create<vector::BitCastOp>( 2252 extractOp.getLoc(), packedType, packedValue); 2253 2254 // Finally extract the desired scalar. 2255 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 2256 extractOp, extractOp.getType(), castedValue, 2257 rewriter.getI64ArrayAttr(index % expandRatio)); 2258 2259 return success(); 2260 } 2261 }; 2262 2263 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 2264 // 2265 // This transforms IR like: 2266 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 2267 // %0 = vector.extract_strided_slice %cast { 2268 // offsets = [4], sizes = [4], strides = [1] 2269 // } : vector<8xf16> to vector<4xf16> 2270 // Into: 2271 // %0 = vector.extract_strided_slice %src { 2272 // offsets = [2], sizes = [2], strides = [1] 2273 // } : vector<4xf32> to vector<2xf32> 2274 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 2275 struct BubbleDownBitCastForStridedSliceExtract 2276 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 2277 using OpRewritePattern::OpRewritePattern; 2278 2279 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 2280 PatternRewriter &rewriter) const override { 2281 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); 2282 if (!castOp) 2283 return failure(); 2284 2285 VectorType castSrcType = castOp.getSourceVectorType(); 2286 VectorType castDstType = castOp.getResultVectorType(); 2287 assert(castSrcType.getRank() == castDstType.getRank()); 2288 2289 int64_t castSrcLastDim = castSrcType.getShape().back(); 2290 int64_t castDstLastDim = castDstType.getShape().back(); 2291 // Require casting to more elements for now; other cases to be implemented. 2292 if (castSrcLastDim > castDstLastDim) 2293 return failure(); 2294 2295 // Only accept all one strides for now. 2296 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(), 2297 [](const APInt &val) { return !val.isOneValue(); })) 2298 return failure(); 2299 2300 unsigned rank = extractOp.getVectorType().getRank(); 2301 assert(castDstLastDim % castSrcLastDim == 0); 2302 int64_t expandRatio = castDstLastDim / castSrcLastDim; 2303 2304 // If we have a less number of offsets than the rank, then implicitly we 2305 // are selecting the full range for the last bitcasted dimension; other 2306 // dimensions aren't affected. Otherwise, we need to scale down the last 2307 // dimension's offset given we are extracting from less elements now. 2308 ArrayAttr newOffsets = extractOp.getOffsets(); 2309 if (newOffsets.size() == rank) { 2310 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2311 if (offsets.back() % expandRatio != 0) 2312 return failure(); 2313 offsets.back() = offsets.back() / expandRatio; 2314 newOffsets = rewriter.getI64ArrayAttr(offsets); 2315 } 2316 2317 // Similarly for sizes. 
2318 ArrayAttr newSizes = extractOp.getSizes(); 2319 if (newSizes.size() == rank) { 2320 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2321 if (sizes.back() % expandRatio != 0) 2322 return failure(); 2323 sizes.back() = sizes.back() / expandRatio; 2324 newSizes = rewriter.getI64ArrayAttr(sizes); 2325 } 2326 2327 SmallVector<int64_t, 4> dims = 2328 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2329 dims.back() = dims.back() / expandRatio; 2330 VectorType newExtractType = 2331 VectorType::get(dims, castSrcType.getElementType()); 2332 2333 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2334 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, 2335 newSizes, extractOp.getStrides()); 2336 2337 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2338 extractOp, extractOp.getType(), newExtractOp); 2339 2340 return success(); 2341 } 2342 }; 2343 2344 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2345 // 2346 // This transforms IR like: 2347 // %0 = vector.insert_strided_slice %src, %dst { 2348 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2349 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2350 // Into: 2351 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2352 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2353 // %2 = vector.insert_strided_slice %src, %dst { 2354 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2355 struct BubbleUpBitCastForStridedSliceInsert 2356 : public OpRewritePattern<vector::BitCastOp> { 2357 using OpRewritePattern::OpRewritePattern; 2358 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2359 PatternRewriter &rewriter) const override { 2360 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2361 VectorType castDstType = bitcastOp.getResultVectorType(); 2362 assert(castSrcType.getRank() == castDstType.getRank()); 2363 2364 int64_t castSrcLastDim = castSrcType.getShape().back(); 2365 int64_t castDstLastDim = castDstType.getShape().back(); 2366 // Require casting to less elements for now; other cases to be implemented. 2367 if (castSrcLastDim < castDstLastDim) 2368 return failure(); 2369 2370 assert(castSrcLastDim % castDstLastDim == 0); 2371 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2372 2373 auto insertOp = 2374 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>(); 2375 if (!insertOp) 2376 return failure(); 2377 2378 // Only accept all one strides for now. 2379 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(), 2380 [](const APInt &val) { return !val.isOneValue(); })) 2381 return failure(); 2382 2383 unsigned rank = insertOp.getSourceVectorType().getRank(); 2384 // Require insert op to have the same rank for the source and destination 2385 // vector; other cases to be implemented. 
2386 if (rank != insertOp.getDestVectorType().getRank()) 2387 return failure(); 2388 2389 ArrayAttr newOffsets = insertOp.getOffsets(); 2390 assert(newOffsets.size() == rank); 2391 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2392 if (offsets.back() % shrinkRatio != 0) 2393 return failure(); 2394 offsets.back() = offsets.back() / shrinkRatio; 2395 newOffsets = rewriter.getI64ArrayAttr(offsets); 2396 2397 SmallVector<int64_t, 4> srcDims = 2398 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2399 srcDims.back() = srcDims.back() / shrinkRatio; 2400 VectorType newCastSrcType = 2401 VectorType::get(srcDims, castDstType.getElementType()); 2402 2403 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2404 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 2405 2406 SmallVector<int64_t, 4> dstDims = 2407 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2408 dstDims.back() = dstDims.back() / shrinkRatio; 2409 VectorType newCastDstType = 2410 VectorType::get(dstDims, castDstType.getElementType()); 2411 2412 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2413 bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 2414 2415 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2416 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2417 insertOp.getStrides()); 2418 2419 return success(); 2420 } 2421 }; 2422 2423 // Helper that returns a vector comparison that constructs a mask: 2424 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2425 // 2426 // If `dim == 0` then the result will be a 0-D vector. 2427 // 2428 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2429 // much more compact, IR for this operation, but LLVM eventually 2430 // generates more elaborate instructions for this intrinsic since it 2431 // is very conservative on the boundary conditions. 2432 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2433 bool force32BitVectorIndices, int64_t dim, 2434 Value b, Value *off = nullptr) { 2435 auto loc = op->getLoc(); 2436 // If we can assume all indices fit in 32-bit, we perform the vector 2437 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2438 // Otherwise we perform the vector comparison using 64-bit indices. 2439 Type idxType = 2440 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 2441 DenseIntElementsAttr indicesAttr; 2442 if (dim == 0 && force32BitVectorIndices) { 2443 indicesAttr = DenseIntElementsAttr::get( 2444 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2445 } else if (dim == 0) { 2446 indicesAttr = DenseIntElementsAttr::get( 2447 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2448 } else if (force32BitVectorIndices) { 2449 indicesAttr = rewriter.getI32VectorAttr( 2450 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2451 } else { 2452 indicesAttr = rewriter.getI64VectorAttr( 2453 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2454 } 2455 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2456 // Add in an offset if requested. 2457 if (off) { 2458 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2459 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2460 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2461 } 2462 // Construct the vector comparison. 
2463 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2464 Value bounds = 2465 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2466 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2467 bounds); 2468 } 2469 2470 template <typename ConcreteOp> 2471 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2472 public: 2473 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2474 : mlir::OpRewritePattern<ConcreteOp>(context), 2475 force32BitVectorIndices(enableIndexOpt) {} 2476 2477 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2478 PatternRewriter &rewriter) const override { 2479 if (!xferOp.hasOutOfBoundsDim()) 2480 return failure(); 2481 2482 if (xferOp.getVectorType().getRank() > 1 || 2483 llvm::size(xferOp.getIndices()) == 0) 2484 return failure(); 2485 2486 Location loc = xferOp->getLoc(); 2487 VectorType vtp = xferOp.getVectorType(); 2488 2489 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2490 // set and [dim - offset .. vector_length) unset. 2491 // 2492 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2493 // dimensions here. 2494 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 2495 Value off = xferOp.getIndices()[lastIndex]; 2496 Value dim = 2497 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 2498 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2499 Value mask = rewriter.create<vector::CreateMaskOp>( 2500 loc, 2501 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2502 vtp.getNumScalableDims()), 2503 b); 2504 if (xferOp.getMask()) { 2505 // Intersect the in-bounds with the mask specified as an op parameter. 2506 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 2507 } 2508 2509 rewriter.updateRootInPlace(xferOp, [&]() { 2510 xferOp.getMaskMutable().assign(mask); 2511 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 2512 }); 2513 2514 return success(); 2515 } 2516 2517 private: 2518 const bool force32BitVectorIndices; 2519 }; 2520 2521 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2522 class VectorCreateMaskOpConversion 2523 : public OpRewritePattern<vector::CreateMaskOp> { 2524 public: 2525 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2526 bool enableIndexOpt) 2527 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2528 force32BitVectorIndices(enableIndexOpt) {} 2529 2530 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2531 PatternRewriter &rewriter) const override { 2532 auto dstType = op.getType(); 2533 if (dstType.cast<VectorType>().isScalable()) 2534 return failure(); 2535 int64_t rank = dstType.getRank(); 2536 if (rank > 1) 2537 return failure(); 2538 rewriter.replaceOp( 2539 op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 2540 rank == 0 ? 0 : dstType.getDimSize(0), 2541 op.getOperand(0))); 2542 return success(); 2543 } 2544 2545 private: 2546 const bool force32BitVectorIndices; 2547 }; 2548 2549 // Drop inner most contiguous unit dimensions from transfer_read operand. 2550 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2551 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2552 2553 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2554 PatternRewriter &rewriter) const override { 2555 // TODO: support 0-d corner case. 
2556 if (readOp.getTransferRank() == 0) 2557 return failure(); 2558 2559 // TODO: support mask. 2560 if (readOp.getMask()) 2561 return failure(); 2562 2563 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>(); 2564 if (!srcType || !srcType.hasStaticShape()) 2565 return failure(); 2566 2567 if (!readOp.getPermutationMap().isMinorIdentity()) 2568 return failure(); 2569 2570 auto targetType = readOp.getVectorType(); 2571 if (targetType.getRank() <= 1) 2572 return failure(); 2573 2574 SmallVector<int64_t> srcStrides; 2575 int64_t srcOffset; 2576 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2577 return failure(); 2578 2579 size_t dimsToDrop = 0; 2580 for (size_t i = 1; i < srcStrides.size(); ++i) { 2581 int dim = srcType.getRank() - i - 1; 2582 if (srcStrides[dim] == 1) { 2583 dimsToDrop++; 2584 } else { 2585 break; 2586 } 2587 } 2588 if (dimsToDrop == 0) 2589 return failure(); 2590 2591 auto resultTargetVecType = 2592 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2593 targetType.getElementType()); 2594 2595 MemRefType resultMemrefType; 2596 if (srcType.getLayout().getAffineMap().isIdentity()) { 2597 resultMemrefType = MemRefType::get( 2598 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2599 {}, srcType.getMemorySpaceAsInt()); 2600 } else { 2601 AffineMap map = srcType.getLayout().getAffineMap(); 2602 int numSymbols = map.getNumSymbols(); 2603 for (size_t i = 0; i < dimsToDrop; ++i) { 2604 int dim = srcType.getRank() - i - 1; 2605 map = map.replace(rewriter.getAffineDimExpr(dim), 2606 rewriter.getAffineConstantExpr(0), 2607 map.getNumDims() - 1, numSymbols); 2608 } 2609 resultMemrefType = MemRefType::get( 2610 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2611 map, srcType.getMemorySpaceAsInt()); 2612 } 2613 2614 auto loc = readOp.getLoc(); 2615 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2616 SmallVector<int64_t> strides(srcType.getRank(), 1); 2617 2618 ArrayAttr inBoundsAttr = 2619 readOp.getInBounds() 2620 ? rewriter.getArrayAttr( 2621 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) 2622 : ArrayAttr(); 2623 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2624 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), 2625 strides); 2626 auto permMap = getTransferMinorIdentityMap( 2627 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2628 Value result = rewriter.create<vector::TransferReadOp>( 2629 loc, resultTargetVecType, rankedReducedView, 2630 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2631 readOp.getPadding(), 2632 // TODO: support mask. 2633 /*mask=*/Value(), inBoundsAttr); 2634 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2635 result); 2636 return success(); 2637 } 2638 }; 2639 2640 namespace { 2641 2642 /// This function checks to see if the vector combining kind 2643 /// is consistent with the integer or float element type. 
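/// For example: isValidKind(/*isInt=*/true, CombiningKind::MINF) is false,
/// while isValidKind(/*isInt=*/true, CombiningKind::ADD) and
/// isValidKind(/*isInt=*/false, CombiningKind::MAXF) are both true.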
2644 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2645 using vector::CombiningKind; 2646 enum class KindType { FLOAT, INT, INVALID }; 2647 KindType type{KindType::INVALID}; 2648 switch (kind) { 2649 case CombiningKind::MINF: 2650 case CombiningKind::MAXF: 2651 type = KindType::FLOAT; 2652 break; 2653 case CombiningKind::MINUI: 2654 case CombiningKind::MINSI: 2655 case CombiningKind::MAXUI: 2656 case CombiningKind::MAXSI: 2657 case CombiningKind::AND: 2658 case CombiningKind::OR: 2659 case CombiningKind::XOR: 2660 type = KindType::INT; 2661 break; 2662 case CombiningKind::ADD: 2663 case CombiningKind::MUL: 2664 type = isInt ? KindType::INT : KindType::FLOAT; 2665 break; 2666 } 2667 bool isValidIntKind = (type == KindType::INT) && isInt; 2668 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2669 return (isValidIntKind || isValidFloatKind); 2670 } 2671 2672 /// This function constructs the appropriate integer or float 2673 /// operation given the vector combining kind and operands. The 2674 /// supported int operations are : add, mul, min (signed/unsigned), 2675 /// max(signed/unsigned), and, or, xor. The supported float 2676 /// operations are : add, mul, min and max. 2677 static Value genOperator(Location loc, Value x, Value y, 2678 vector::CombiningKind kind, 2679 PatternRewriter &rewriter) { 2680 using vector::CombiningKind; 2681 2682 auto elType = x.getType().cast<VectorType>().getElementType(); 2683 bool isInt = elType.isIntOrIndex(); 2684 2685 Value combinedResult{nullptr}; 2686 switch (kind) { 2687 case CombiningKind::ADD: 2688 if (isInt) 2689 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2690 else 2691 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2692 break; 2693 case CombiningKind::MUL: 2694 if (isInt) 2695 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2696 else 2697 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2698 break; 2699 case CombiningKind::MINUI: 2700 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2701 break; 2702 case CombiningKind::MINSI: 2703 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2704 break; 2705 case CombiningKind::MAXUI: 2706 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2707 break; 2708 case CombiningKind::MAXSI: 2709 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2710 break; 2711 case CombiningKind::AND: 2712 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2713 break; 2714 case CombiningKind::OR: 2715 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2716 break; 2717 case CombiningKind::XOR: 2718 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2719 break; 2720 case CombiningKind::MINF: 2721 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2722 break; 2723 case CombiningKind::MAXF: 2724 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2725 break; 2726 } 2727 return combinedResult; 2728 } 2729 2730 /// Convert vector.scan op into arith ops and 2731 /// vector.insert_strided_slice/extract_strided_slice 2732 /// 2733 /// Ex: 2734 /// ``` 2735 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2736 /// 1} : 2737 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2738 /// ``` 2739 /// Gets converted to: 2740 /// ``` 2741 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2742 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2743 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32> %1 = 2744 /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} 2745 /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice 2746 /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : 2747 /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : 2748 /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], 2749 /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = 2750 /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], 2751 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, 2752 /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = 2753 /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = 2754 /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : 2755 /// vector<2x3xi32>, vector<2xi32> 2756 /// ``` 2757 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { 2758 using OpRewritePattern<vector::ScanOp>::OpRewritePattern; 2759 2760 LogicalResult matchAndRewrite(vector::ScanOp scanOp, 2761 PatternRewriter &rewriter) const override { 2762 auto loc = scanOp.getLoc(); 2763 VectorType destType = scanOp.getDestType(); 2764 ArrayRef<int64_t> destShape = destType.getShape(); 2765 auto elType = destType.getElementType(); 2766 bool isInt = elType.isIntOrIndex(); 2767 if (!isValidKind(isInt, scanOp.getKind())) 2768 return failure(); 2769 2770 VectorType resType = VectorType::get(destShape, elType); 2771 Value result = rewriter.create<arith::ConstantOp>( 2772 loc, resType, rewriter.getZeroAttr(resType)); 2773 int64_t reductionDim = scanOp.getReductionDim(); 2774 bool inclusive = scanOp.getInclusive(); 2775 int64_t destRank = destType.getRank(); 2776 VectorType initialValueType = scanOp.getInitialValueType(); 2777 int64_t initialValueRank = initialValueType.getRank(); 2778 2779 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2780 reductionShape[reductionDim] = 1; 2781 VectorType reductionType = VectorType::get(reductionShape, elType); 2782 SmallVector<int64_t> offsets(destRank, 0); 2783 SmallVector<int64_t> strides(destRank, 1); 2784 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2785 sizes[reductionDim] = 1; 2786 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2787 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2788 2789 Value lastOutput, lastInput; 2790 for (int i = 0; i < destShape[reductionDim]; i++) { 2791 offsets[reductionDim] = i; 2792 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2793 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2794 loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, 2795 scanStrides); 2796 Value output; 2797 if (i == 0) { 2798 if (inclusive) { 2799 output = input; 2800 } else { 2801 if (initialValueRank == 0) { 2802 // ShapeCastOp cannot handle 0-D vectors 2803 output = rewriter.create<vector::BroadcastOp>( 2804 loc, input.getType(), scanOp.getInitialValue()); 2805 } else { 2806 output = rewriter.create<vector::ShapeCastOp>( 2807 loc, input.getType(), scanOp.getInitialValue()); 2808 } 2809 } 2810 } else { 2811 Value y = inclusive ? 
input : lastInput; 2812 output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); 2813 assert(output != nullptr); 2814 } 2815 result = rewriter.create<vector::InsertStridedSliceOp>( 2816 loc, output, result, offsets, strides); 2817 lastOutput = output; 2818 lastInput = input; 2819 } 2820 2821 Value reduction; 2822 if (initialValueRank == 0) { 2823 Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0); 2824 reduction = 2825 rewriter.create<vector::BroadcastOp>(loc, initialValueType, v); 2826 } else { 2827 reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType, 2828 lastOutput); 2829 } 2830 2831 rewriter.replaceOp(scanOp, {result, reduction}); 2832 return success(); 2833 } 2834 }; 2835 2836 } // namespace 2837 2838 void mlir::vector::populateVectorMaskMaterializationPatterns( 2839 RewritePatternSet &patterns, bool force32BitVectorIndices) { 2840 patterns.add<VectorCreateMaskOpConversion, 2841 MaterializeTransferMask<vector::TransferReadOp>, 2842 MaterializeTransferMask<vector::TransferWriteOp>>( 2843 patterns.getContext(), force32BitVectorIndices); 2844 } 2845 2846 void mlir::vector::populateShapeCastFoldingPatterns( 2847 RewritePatternSet &patterns) { 2848 patterns.add<ShapeCastOpFolder>(patterns.getContext()); 2849 } 2850 2851 void mlir::vector::populateBubbleVectorBitCastOpPatterns( 2852 RewritePatternSet &patterns) { 2853 patterns.add<BubbleDownVectorBitCastForExtract, 2854 BubbleDownBitCastForStridedSliceExtract, 2855 BubbleUpBitCastForStridedSliceInsert>(patterns.getContext()); 2856 } 2857 2858 void mlir::vector::populateVectorBroadcastLoweringPatterns( 2859 RewritePatternSet &patterns) { 2860 patterns.add<BroadcastOpLowering>(patterns.getContext()); 2861 } 2862 2863 void mlir::vector::populateVectorMaskOpLoweringPatterns( 2864 RewritePatternSet &patterns) { 2865 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>( 2866 patterns.getContext()); 2867 } 2868 2869 void mlir::vector::populateVectorShapeCastLoweringPatterns( 2870 RewritePatternSet &patterns) { 2871 patterns.add<ShapeCastOp2DDownCastRewritePattern, 2872 ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( 2873 patterns.getContext()); 2874 } 2875 2876 void mlir::vector::populateVectorContractLoweringPatterns( 2877 RewritePatternSet &patterns, VectorTransformsOptions options) { 2878 patterns.add<OuterProductOpLowering>(patterns.getContext()); 2879 patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering, 2880 ContractionOpToOuterProductOpLowering>(options, 2881 patterns.getContext()); 2882 } 2883 2884 void mlir::vector::populateVectorTransposeLoweringPatterns( 2885 RewritePatternSet &patterns, VectorTransformsOptions options) { 2886 patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( 2887 options, patterns.getContext()); 2888 } 2889 2890 void mlir::vector::populateVectorReductionToContractPatterns( 2891 RewritePatternSet &patterns) { 2892 patterns.add<MultiReduceToContract, CombineContractBroadcast, 2893 CombineContractTranspose, ReorderCastOpsOnBroadcast, 2894 ReorderElementwiseOpsOnTranspose>(patterns.getContext()); 2895 } 2896 2897 void mlir::vector:: 2898 populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 2899 RewritePatternSet &patterns) { 2900 patterns.add<DropInnerMostUnitDims>(patterns.getContext()); 2901 } 2902 2903 void mlir::vector::populateVectorTransferLoweringPatterns( 2904 RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) { 2905 patterns.add<TransferReadToVectorLoadLowering, 2906 
TransferWriteToVectorStoreLowering>(patterns.getContext(), 2907 maxTransferRank); 2908 patterns 2909 .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>( 2910 patterns.getContext()); 2911 } 2912 2913 void mlir::vector::populateVectorScanLoweringPatterns( 2914 RewritePatternSet &patterns) { 2915 patterns.add<ScanToArithOps>(patterns.getContext()); 2916 } 2917
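// Typical driver-side usage of the populate* entry points above (an
// illustrative sketch only; `applyPatternsAndFoldGreedily` comes from
// mlir/Transforms/GreedyPatternRewriteDriver.h and `rootOp`/`ctx` are
// assumed to be provided by the caller):
//
//   RewritePatternSet patterns(ctx);
//   vector::VectorTransformsOptions options;
//   options.vectorContractLowering = vector::VectorContractLowering::OuterProduct;
//   vector::populateVectorContractLoweringPatterns(patterns, options);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/llvm::None);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));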