//===- VectorTransforms.cpp - Conversion within the Vector dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites as 1->N patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-to-vector"

using namespace mlir;
using namespace mlir::vector;

// Helper to find an index in an affine map.
static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return None;
}

// Helper to construct iterator types with one index removed.
static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
                                            int64_t index) {
  SmallVector<Attribute, 4> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr, 4> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renamed when occurring
    // after the removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;
  Type lowType = VectorType::Builder(type).dropDim(0);
  // At extraction dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
  }
  // Unroll leading dimensions.
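  // For example, with type vector<2x3x4xf32>, index = 1 and position p, each
  // vector<3x4xf32> slice is extracted below, position p is recursively
  // extracted from it (yielding vector<4xf32>), and the result is inserted
  // into the final vector<2x4xf32>.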
  VectorType vType = lowType.cast<VectorType>();
  Type resType = VectorType::Builder(type).dropDim(index);
  auto resVectorType = resType.cast<VectorType>();
  Value result = rewriter.create<arith::ConstantOp>(
      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
                                               posAttr);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0) {
    auto posAttr = rewriter.getI64ArrayAttr(pos);
    return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
  }
  // Unroll leading dimensions.
  Type lowType = VectorType::Builder(type).dropDim(0);
  VectorType vType = lowType.cast<VectorType>();
  Type insType = VectorType::Builder(vType).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    auto posAttr = rewriter.getI64ArrayAttr(d);
    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
    Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, posAttr);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = rewriter.create<vector::InsertOp>(loc, type, sto, result, posAttr);
  }
  return result;
}

template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Helper to create an arithmetic operation associated with a kind of
/// contraction.
static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
                                             Value acc,
                                             vector::CombiningKind kind,
                                             PatternRewriter &rewriter,
                                             bool isInt) {
  using vector::CombiningKind;
  Value mul;
  if (isInt) {
    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
      // Only valid for floating point types.
      return Optional<Value>();
    mul = rewriter.create<arith::MulIOp>(loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return Optional<Value>();
    // Special case for fused multiply-add.
    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
    }
    mul = rewriter.create<arith::MulFOp>(loc, x, y);
  }
  if (!acc)
    return Optional<Value>(mul);
  return makeArithReduction(rewriter, loc, kind, mul, acc);
}

/// Return the positions of the reductions in the given map.
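/// For example, with indexing map (d0, d1, d2) -> (d0, d2) and iterator types
/// ["parallel", "parallel", "reduction"], this returns {1}: result 1 of the
/// map is d2, which corresponds to the reduction iterator.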
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position. Return
/// llvm::None if the dimension is not in the map results.
static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return llvm::None;
}

namespace {

/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
//
// Example:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//
struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    // Check if 'shapeCastOp' has vector source/result type.
    auto sourceVectorType =
        shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
    auto resultVectorType =
        shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
    if (!sourceVectorType || !resultVectorType)
      return failure();

    // Check if shape cast op source operand is also a shape cast op.
    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
        shapeCastOp.getSource().getDefiningOp());
    if (!sourceShapeCastOp)
      return failure();
    auto operandSourceVectorType =
        sourceShapeCastOp.getSource().getType().cast<VectorType>();
    auto operandResultVectorType = sourceShapeCastOp.getType();

    // Check if shape cast operations invert each other.
    if (operandSourceVectorType != resultVectorType ||
        operandResultVectorType != sourceVectorType)
      return failure();

    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
    return success();
  }
};

/// Progressive lowering of BroadcastOp.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    VectorType dstType = op.getVectorType();
    VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
    Type eltType = dstType.getElementType();

    // Scalar to any vector can use splat.
    if (!srcType) {
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
      return success();
    }

    // Determine rank of source and destination.
    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();

    // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
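    // For example, broadcasting vector<1xf32> to vector<4xf32> extracts the
    // single source element and splats it across the destination vector.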
    if (srcRank <= 1 && dstRank == 1) {
      Value ext;
      if (srcRank == 0)
        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
      else
        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
      return success();
    }

    // Duplicate this rank.
    // For example:
    //   %x = broadcast %y  : k-D to n-D, k < n
    // becomes:
    //   %b = broadcast %y  : k-D to (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    // becomes:
    //   %b = [%y,%y]       : (n-1)-D
    //   %x = [%b,%b,%b,%b] : n-D
    if (srcRank < dstRank) {
      // Duplication.
      VectorType resType =
          VectorType::get(dstType.getShape().drop_front(), eltType);
      Value bcst =
          rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
      Value result = rewriter.create<arith::ConstantOp>(
          loc, dstType, rewriter.getZeroAttr(dstType));
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      rewriter.replaceOp(op, result);
      return success();
    }

    // Find non-matching dimension, if any.
    assert(srcRank == dstRank);
    int64_t m = -1;
    for (int64_t r = 0; r < dstRank; r++)
      if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
        m = r;
        break;
      }

    // All trailing dimensions are the same. Simply pass through.
    if (m == -1) {
      rewriter.replaceOp(op, op.getSource());
      return success();
    }

    // Any non-matching dimension forces a stretch along this rank.
    // For example:
    //   %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
    // becomes:
    //   %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
    //   %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
    //   %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
    //   %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
    //   %x = [%a,%b,%c,%d]
    // becomes:
    //   %u = broadcast %y[0][0] : vector<2xf32> to vector<2x2xf32>
    //   %v = broadcast %y[1][0] : vector<2xf32> to vector<2x2xf32>
    //   %a = [%u, %v]
    //   ..
    //   %x = [%a,%b,%c,%d]
    VectorType resType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    if (m == 0) {
      // Stretch at start.
      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
      Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
    } else {
      // Stretch not at start.
      for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
        Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
        Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
        result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
/// transposed.
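/// For example, the pattern [1, 0, 2, 3] is pruned to [1, 0]: the two
/// rightmost dimensions already sit in their identity positions, so they are
/// left in vector form.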
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
                            SmallVectorImpl<int64_t> &result) {
  size_t numTransposedDims = transpose.size();
  for (size_t transpDim : llvm::reverse(transpose)) {
    if (transpDim != numTransposedDims - 1)
      break;
    numTransposedDims--;
  }

  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
}

/// Progressive lowering of TransposeOp.
/// One:
///   %x = vector.transpose %y, [1, 0]
/// is replaced by:
///   %z = arith.constant dense<0.000000e+00>
///   %0 = vector.extract %y[0, 0]
///   %1 = vector.insert %0, %z [0, 0]
///   ..
///   %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    Value input = op.getVector();
    VectorType inputType = op.getVectorType();
    VectorType resType = op.getResultType();

    // Set up convenience transposition table.
    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());

    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Shuffle &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
      return rewriter.notifyMatchFailure(
          op, "Options specify lowering to shuffle");

    // Handle a true 2-D matrix transpose differently when requested.
    if (vectorTransformOptions.vectorTransposeLowering ==
            vector::VectorTransposeLowering::Flat &&
        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
      Type flattenedType =
          VectorType::get(resType.getNumElements(), resType.getElementType());
      auto matrix =
          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
      Value trans = rewriter.create<vector::FlatTransposeOp>(
          loc, flattenedType, matrix, rows, columns);
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
      return success();
    }

    // Generate unrolled extract/insert ops. We do not unroll the rightmost
    // (i.e., highest-order) dimensions that are not transposed and leave them
    // in vector form to improve performance. Therefore, we prune those
    // dimensions from the shape/transpose data structures used to generate the
    // extract/insert ops.
    SmallVector<int64_t, 4> prunedTransp;
    pruneNonTransposedDims(transp, prunedTransp);
    size_t numPrunedDims = transp.size() - prunedTransp.size();
    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
    SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
    auto prunedInStrides = computeStrides(prunedInShape, ones);

    // Generates the extract/insert operations for every scalar/vector element
    // of the leftmost transposed dimensions. We traverse every transpose
    // element using a linearized index that we delinearize to generate the
    // appropriate indices for the extract/insert operations.
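    // For example, transposing vector<2x3x4xf32> with permutation [1, 0, 2]
    // prunes the trailing dimension and generates 2 * 3 = 6 extract/insert
    // pairs, each moving a whole vector<4xf32>.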
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);

    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
         ++linearIdx) {
      auto extractIdxs = delinearize(prunedInStrides, linearIdx);
      SmallVector<int64_t, 4> insertIdxs(extractIdxs);
      applyPermutationToVector(insertIdxs, prunedTransp);
      Value extractOp =
          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
      result =
          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Rewrite a 2-D vector.transpose as a sequence of:
///   vector.shape_cast 2D -> 1D
///   vector.shuffle
///   vector.shape_cast 1D -> 2D
class TransposeOp2DToShuffleLowering
    : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOp2DToShuffleLowering(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context)
      : OpRewritePattern<vector::TransposeOp>(context),
        vectorTransformOptions(vectorTransformOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType srcType = op.getVectorType();
    if (srcType.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");

    SmallVector<int64_t, 4> transp;
    for (auto attr : op.getTransp())
      transp.push_back(attr.cast<IntegerAttr>().getInt());
    if (transp[0] != 1 || transp[1] != 0)
      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");

    if (vectorTransformOptions.vectorTransposeLowering !=
        VectorTransposeLowering::Shuffle)
      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");

    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
    Value casted = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({m * n}, srcType.getElementType()),
        op.getVector());
    SmallVector<int64_t> mask;
    mask.reserve(m * n);
    for (int64_t j = 0; j < n; ++j)
      for (int64_t i = 0; i < m; ++i)
        mask.push_back(i * n + j);

    Value shuffled =
        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
                                                     shuffled);

    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
};

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero-result
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
    VectorType resType = op.getVectorType();
    Type eltType = resType.getElementType();
    bool isInt = eltType.isa<IntegerType, IndexType>();
    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
    vector::CombiningKind kind = op.getKind();

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
                                                   kind, rewriter, isInt);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(op, mult.value());
      return success();
    }

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      auto pos = rewriter.getI64ArrayAttr(d);
      Value x =
          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
      Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
      Optional<Value> m =
          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
      if (!m.has_value())
        return failure();
      result = rewriter.create<vector::InsertOp>(loc, resType, m.value(),
                                                 result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Lower vector.contract with all reduction dimensions of size 1 to
/// elementwise ops when possible.
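/// For example, a contraction of vector<4x1xf32> with vector<1x8xf32> into
/// vector<4x8xf32>, where the only reduction dimension has size 1, becomes
/// broadcasts/transposes of the operands followed by an elementwise
/// multiply-accumulate (a vector.fma for floats) on vector<4x8xf32>.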
struct ContractOpToElementwise
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorTransformsOptions vectorTransformOptions,
      MLIRContext *context,
      const FilterConstraintType &constraint = defaultFilter)
      : OpRewritePattern<vector::ContractionOp>(context),
        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    // TODO: implement masks
    if (llvm::size(contractOp.getMasks()) != 0)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorTransformOptions.vectorContractLowering !=
        vector::VectorContractLowering::ParallelArith)
      return failure();
    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must be of size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      llvm::Optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast
        // it.
        lhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      llvm::Optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast
        // it.
        rhsDims.push_back(
            contractOp.getResultType().cast<VectorType>().getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = rewriter.create<vector::ExtractOp>(
        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
    newRhs = rewriter.create<vector::ExtractOp>(
        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
    Optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    rewriter.replaceOp(contractOp, {*result});
    return success();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorTransformsOptions vectorTransformOptions;
  FilterConstraintType filter;
};

/// Progressive lowering of ConstantMaskOp.
/// One:
///   %x = vector.constant_mask [a,b]
/// is replaced by:
///   %z = zero-result
///   %l = vector.constant_mask [b]
///   %4 = vector.insert %l, %z[0]
///   ..
///   %x = vector.insert %l, %..[a-1]
/// until a one-dimensional vector is reached. All these operations
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern<vector::ConstantMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto dstType = op.getType();
    auto eltType = dstType.getElementType();
    auto dimSizes = op.getMaskDimSizes();
    int64_t rank = dstType.getRank();

    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType,
          DenseIntElementsAttr::get(
              VectorType::get(ArrayRef<int64_t>{}, rewriter.getI1Type()),
              ArrayRef<bool>{value}));
      return success();
    }

    // Scalable constant masks can only be lowered for the "none set" case.
    if (dstType.cast<VectorType>().isScalable()) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, DenseElementsAttr::get(dstType, false));
      return success();
    }

    int64_t trueDim = std::min(dstType.getDimSize(0),
                               dimSizes[0].cast<IntegerAttr>().getInt());

    if (rank == 1) {
      // Express constant 1-D case in explicit vector form:
      //   [T,..,T,F,..,F].
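      // For example, vector.constant_mask [3] : vector<5xi1> becomes the
      // constant dense<[true, true, true, false, false]>.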
      SmallVector<bool, 4> values(dstType.getDimSize(0));
      for (int64_t d = 0; d < trueDim; d++)
        values[d] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, dstType, rewriter.getBoolVectorAttr(values));
      return success();
    }

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    SmallVector<int64_t, 4> newDimSizes;
    for (int64_t r = 1; r < rank; r++)
      newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
    Value trueVal = rewriter.create<vector::ConstantMaskOp>(
        loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < trueDim; d++) {
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Progressive lowering of CreateMaskOp.
/// One:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// is replaced by:
///   %l = vector.create_mask ... : vector<...>  ; one lower rank
///   %0 = arith.cmpi "slt", %ci, %a       |
///   %1 = select %0, %l, %zeroes          |
///   %r = vector.insert %1, %pr [i]       | d-times
///   %x = ....
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern<vector::CreateMaskOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getResult().getType().cast<VectorType>();
    int64_t rank = dstType.getRank();
    if (rank <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    auto loc = op.getLoc();
    auto eltType = dstType.getElementType();
    int64_t dim = dstType.getDimSize(0);
    Value idx = op.getOperand(0);

    VectorType lowType =
        VectorType::get(dstType.getShape().drop_front(), eltType);
    Value trueVal = rewriter.create<vector::CreateMaskOp>(
        loc, lowType, op.getOperands().drop_front());
    Value falseVal = rewriter.create<arith::ConstantOp>(
        loc, lowType, rewriter.getZeroAttr(lowType));
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstType, rewriter.getZeroAttr(dstType));
    for (int64_t d = 0; d < dim; d++) {
      Value bnd =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(d));
      Value val = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 bnd, idx);
      Value sel = rewriter.create<arith::SelectOp>(loc, val, trueVal, falseVal);
      auto pos = rewriter.getI64ArrayAttr(d);
      result =
          rewriter.create<vector::InsertOp>(loc, dstType, sel, result, pos);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
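// For example, a vector.shape_cast from vector<2x3xf32> to vector<3x2xf32>
// that reaches this pattern is expanded into six scalar extract/insert pairs,
// walking both shapes in row-major order.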
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    // Special case 2D/1D lowerings with better implementations.
    // TODO: make this ND/1D to allow generic ND->1D->MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest
    // and drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t, 4> srcIdx(srcRank);
    SmallVector<int64_t, 4> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }
      Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
    assert(0 <= r && r < tp.getRank());
    if (++idx[r] == tp.getDimSize(r)) {
      idx[r] = 0;
      incIdx(idx, tp, r - 1);
    }
  }
};

/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [2]
///     : vector<8x32x16xf32> to vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<StringRef> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        iteratorTypes.push_back(getParallelIteratorTypeName());
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        iteratorTypes.push_back(getReductionIteratorTypeName());
      }
    }
    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
                                 /*symCount=*/0, exprs, reduceOp.getContext());
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getStrArrayAttr(iteratorTypes));
    return success();
  }
};

/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractTranspose
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      SmallVector<int64_t> perm;
      transposeOp.getTransp(perm);
      AffineMap permutationMap = AffineMap::getPermutationMap(
          extractVector<unsigned>(transposeOp.getTransp()),
          contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};

/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %0, %arg1, %cst_f0
///     : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///          affine_map<(d0, d1, d2) -> (d0, d1)>],
///     iterator_types = ["parallel", "parallel", "reduction"],
///     kind = add} %arg0, %arg1, %cst_f0
///     : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap, 4> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
      if (!broadcast)
        continue;
      // contractionOp can only take vectors as operands.
      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
        continue;
      int64_t rankDiff =
          broadcast.getVectorType().getRank() - srcType.getRank();
      bool innerDimBroadcast = false;
      SmallVector<AffineExpr> originalDims;
      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
        if (dim.value() !=
            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
          innerDimBroadcast = true;
          break;
        }
        originalDims.push_back(
            rewriter.getAffineDimExpr(dim.index() + rankDiff));
      }
      // Contract doesn't support inner dimension broadcast. Once this is
      // relaxed we can remove this case.
      if (innerDimBroadcast)
        continue;

      // It would be incorrect to fold a broadcast onto a reduction dimension
      // of non-unit size.
      bool nonUnitDimReductionBroadcast = false;
      for (int64_t i = 0; i < rankDiff; ++i) {
        if (broadcast.getVectorType().getDimSize(i) != 1 &&
            isReductionIterator(contractOp.getIteratorTypes()
                                    .getValue()[map.getDimPosition(i)])) {
          nonUnitDimReductionBroadcast = true;
          break;
        }
      }
      if (nonUnitDimReductionBroadcast)
        continue;

      AffineMap broadcastMap =
          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
                         contractOp.getContext());
      map = broadcastMap.compose(map);
      *operand = broadcast.getSource();
      changed = true;
    }

    if (!changed)
      return failure();

    // Determine which dims are unused, now that the maps have been composed
    // with the broadcast maps.
    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
    // Compress unused dims.
    for (auto &m : maps)
      m = compressDims(m, unusedDimsBitVector);
    // Compute the combined iterators.
    SmallVector<Attribute, 4> iterators;
    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
      if (!unusedDimsBitVector.test(i))
        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
    }
    // Check that compressing unused dims isn't removing all reduction
    // dimension pairs. For example, if the vector.contract had only one
    // reduction iterator and that was a unit-dimension created by a broadcast,
    // then we should bail here, otherwise we would create a contract without
    // a reduction dimension pair.
    bool hasReductionIteratorApplyingOnBothSides = false;
    for (unsigned i = 0; i < iterators.size(); ++i) {
      if (!isReductionIterator(iterators[i]))
        continue;
      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
        hasReductionIteratorApplyingOnBothSides = true;
        break;
      }
    }
    if (!hasReductionIteratorApplyingOnBothSides)
      return failure();

    // If the compressed maps have a dimension that is not used by either LHS
    // or RHS then the ContractionOp verifier would fail.
    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
      return failure();
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
    return success();
  }
};

/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in the CombineContractBroadcast pattern
/// when casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1)
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
      castResTy = VectorType::get(vecTy.getShape(), castResTy);
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};

/// Reorders elementwise(transpose) to transpose(elementwise). This makes
/// transpose ops and contraction ops closer, which kicks in the
/// CombineContractTranspose pattern when elementwise ops are between these
/// operations.
/// Ex:
/// ```
///   %at = vector.transpose %a, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %bt = vector.transpose %b, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
///   %r = arith.addf %at, %bt : vector<2x4xf32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.addf %a, %b : vector<4x2xf32>
///   %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
/// ```
struct ReorderElementwiseOpsOnTranspose final
    : public OpTraitRewritePattern<OpTrait::Elementwise> {
  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() != 1 || op->getNumRegions() != 0)
      return failure();

    // Make sure all operands are transpose/constant ops and collect their
    // transposition maps.
    SmallVector<ArrayAttr, 4> transposeMaps;
    transposeMaps.reserve(op->getNumOperands());
    // Record the initial type before transposition. We'll use its shape later.
    // Any type will do here as we will check all transpose maps are the same.
    VectorType srcType;
    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        transposeMaps.push_back(transposeOp.getTransp());
        srcType = transposeOp.getVectorType();
      } else if (!matchPattern(operand, m_Constant())) {
        return failure();
      }
    }
    if (transposeMaps.empty())
      return failure();
    // This is an elementwise op, so all transposed operands should have the
    // same type. We need to additionally check that all transposes use the
    // same map.
    if (!llvm::is_splat(transposeMaps))
      return rewriter.notifyMatchFailure(op, "different transpose map");

    SmallVector<Value, 4> srcValues;
    srcValues.reserve(op->getNumOperands());

    // If there are constant operands, we need to insert inverse transposes for
    // them. Calculate the inverse order first.
    auto order = extractVector<unsigned>(transposeMaps.front());
    SmallVector<int64_t> invOrder(order.size());
    for (int i = 0, e = order.size(); i < e; ++i)
      invOrder[order[i]] = i;

    for (Value operand : op->getOperands()) {
      auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
      if (transposeOp) {
        srcValues.push_back(transposeOp.getVector());
      } else {
        // This is a constant. Create a reverse transpose op for it.
        auto vectorType = VectorType::get(
            srcType.getShape(),
            operand.getType().cast<VectorType>().getElementType());
        srcValues.push_back(rewriter.create<vector::TransposeOp>(
            operand.getLoc(), vectorType, operand,
            rewriter.getI64ArrayAttr(invOrder)));
      }
    }

    auto vectorType = VectorType::get(
        srcType.getShape(),
        op->getResultTypes()[0].cast<VectorType>().getElementType());
    Operation *elementwiseOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
                        vectorType, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
        op, op->getResultTypes()[0], elementwiseOp->getResult(0),
        transposeMaps.front());
    return success();
  }
};

} // namespace

/// Creates an AddIOp if `isInt` is true otherwise creates an arith::AddFOp
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::AddIOp>(loc, x, y);
  return rewriter.create<arith::AddFOp>(loc, x, y);
}

/// Creates a MulIOp if `isInt` is true otherwise creates a MulFOp using
/// operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return rewriter.create<arith::MulIOp>(loc, x, y);
  return rewriter.create<arith::MulFOp>(loc, x, y);
}

namespace mlir {

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
///    %mta = maybe_transpose
///    %mtb = maybe_transpose
///    %flattened_a = vector.shape_cast %mta
///    %flattened_b = vector.shape_cast %mtb
///    %flattened_d = vector.matmul %flattened_a, %flattened_b
///    %mtd = vector.shape_cast %flattened_d
///    %d = maybe_untranspose %mtd
///    %e = add %c, %d
/// ```
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                 PatternRewriter &rew) const {
  // TODO: implement masks
  if (llvm::size(op.getMasks()) != 0)
    return failure();
  if (vectorTransformOptions.vectorContractLowering !=
      vector::VectorContractLowering::Matmul)
    return failure();
  if (failed(filter(op)))
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  if (!isParallelIterator(iteratorTypes[0]) ||
      !isParallelIterator(iteratorTypes[1]) ||
      !isReductionIterator(iteratorTypes[2]))
    return failure();

  Type elementType = op.getLhsType().getElementType();
  if (!elementType.isIntOrFloat())
    return failure();

  Type dstElementType = op.getType();
  if (auto vecType = dstElementType.dyn_cast<VectorType>())
    dstElementType = vecType.getElementType();
  if (elementType != dstElementType)
    return failure();

  // Perform lhs + rhs transpositions to conform to matmul row-major semantics.
  // Bail out if the contraction cannot be put in this form.
  MLIRContext *ctx = op.getContext();
  Location loc = op.getLoc();
  AffineExpr m, n, k;
  bindDims(rew.getContext(), m, n, k);
  // LHS must be A(m, k) or A(k, m).
  Value lhs = op.getLhs();
  auto lhsMap = op.getIndexingMapsArray()[0];
  if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx))
    lhs = rew.create<vector::TransposeOp>(loc, lhs, ArrayRef<int64_t>{1, 0});
  else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx))
    return failure();

  // RHS must be B(k, n) or B(n, k).
  Value rhs = op.getRhs();
  auto rhsMap = op.getIndexingMapsArray()[1];
  if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx))
    rhs = rew.create<vector::TransposeOp>(loc, rhs, ArrayRef<int64_t>{1, 0});
  else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx))
    return failure();

  // At this point lhs and rhs are in row-major.
  VectorType lhsType = lhs.getType().cast<VectorType>();
  VectorType rhsType = rhs.getType().cast<VectorType>();
  int64_t lhsRows = lhsType.getDimSize(0);
  int64_t lhsColumns = lhsType.getDimSize(1);
  int64_t rhsColumns = rhsType.getDimSize(1);

  Type flattenedLHSType =
      VectorType::get(lhsType.getNumElements(), lhsType.getElementType());
  lhs = rew.create<vector::ShapeCastOp>(loc, flattenedLHSType, lhs);

  Type flattenedRHSType =
      VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
  rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
                                           rhsColumns);
  mul = rew.create<vector::ShapeCastOp>(
      loc,
      VectorType::get({lhsRows, rhsColumns},
                      getElementTypeOrSelf(op.getAcc().getType())),
      mul);

  // ACC must be C(m, n) or C(n, m).
  auto accMap = op.getIndexingMapsArray()[2];
  if (accMap == AffineMap::get(3, 0, {n, m}, ctx))
    mul = rew.create<vector::TransposeOp>(loc, mul, ArrayRef<int64_t>{1, 0});
  else if (accMap != AffineMap::get(3, 0, {m, n}, ctx))
    llvm_unreachable("invalid contraction semantics");

  Value res =
      elementType.isa<IntegerType>()
          ? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
          : static_cast<Value>(
                rew.create<arith::AddFOp>(loc, op.getAcc(), mul));

  rew.replaceOp(op, res);
  return success();
}

namespace {
struct IteratorType {
  IteratorType(StringRef strRef) : strRef(strRef) {}
  bool isOfType(Attribute attr) const {
    auto sAttr = attr.dyn_cast<StringAttr>();
    return sAttr && sAttr.getValue() == strRef;
  }
  StringRef strRef;
};
struct Par : public IteratorType {
  Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
  Red() : IteratorType(getReductionIteratorTypeName()) {}
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
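/// For example, a row-major (m, k) x (k, n) matmat is emitted as k
/// vector.outerproduct operations: column i of the (transposed) LHS is
/// combined with row i of the RHS and accumulated into the result.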
1460 struct UnrolledOuterProductGenerator 1461 : public StructuredGenerator<vector::ContractionOp> { 1462 UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) 1463 : StructuredGenerator<vector::ContractionOp>(builder, op), 1464 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), 1465 res(op.getAcc()), lhsType(op.getLhsType()) {} 1466 1467 Value t(Value v) { 1468 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1469 return builder.create<vector::TransposeOp>(loc, v, perm); 1470 } 1471 1472 Value promote(Value v, Type dstElementType) { 1473 Type elementType = v.getType(); 1474 auto vecType = elementType.dyn_cast<VectorType>(); 1475 if (vecType) 1476 elementType = vecType.getElementType(); 1477 if (elementType == dstElementType) 1478 return v; 1479 Type promotedType = dstElementType; 1480 if (vecType) 1481 promotedType = VectorType::get(vecType.getShape(), promotedType); 1482 if (dstElementType.isa<FloatType>()) 1483 return builder.create<arith::ExtFOp>(loc, promotedType, v); 1484 return builder.create<arith::ExtSIOp>(loc, promotedType, v); 1485 } 1486 1487 Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) { 1488 assert(reductionSize > 0); 1489 Type resElementType = res.getType().cast<VectorType>().getElementType(); 1490 for (int64_t k = 0; k < reductionSize; ++k) { 1491 Value a = builder.create<vector::ExtractOp>(loc, lhs, k); 1492 Value b = builder.create<vector::ExtractOp>(loc, rhs, k); 1493 a = promote(a, resElementType); 1494 b = promote(b, resElementType); 1495 res = builder.create<vector::OuterProductOp>(loc, res.getType(), a, b, 1496 res, kind); 1497 } 1498 return res; 1499 } 1500 1501 /// Two outer parallel, one inner reduction (matmat flavor). 1502 FailureOr<Value> matmat() { 1503 if (!iters({Par(), Par(), Red()})) 1504 return failure(); 1505 // Set up the parallel/reduction structure in the right form. 1506 AffineExpr m, n, k; 1507 bindDims(builder.getContext(), m, n, k); 1508 // Classical row-major matmul: Just permute the lhs. 1509 if (layout({{m, k}, {k, n}, {m, n}})) 1510 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1511 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 1512 if (layout({{m, k}, {n, k}, {m, n}})) { 1513 Value tlhs = t(lhs); 1514 return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1)); 1515 } 1516 // No need to permute anything. 1517 if (layout({{k, m}, {k, n}, {m, n}})) 1518 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1519 // Just permute the rhs. 1520 if (layout({{k, m}, {n, k}, {m, n}})) 1521 return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0)); 1522 // Transposed output: swap RHS and LHS. 1523 // Classical row-major matmul: permute the lhs. 1524 if (layout({{m, k}, {k, n}, {n, m}})) 1525 return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1)); 1526 // TODO: may be better to fail and use some vector<k> -> scalar reduction. 
1527 if (layout({{m, k}, {n, k}, {n, m}})) { 1528 Value trhs = t(rhs); 1529 return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1)); 1530 } 1531 if (layout({{k, m}, {k, n}, {n, m}})) 1532 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1533 if (layout({{k, m}, {n, k}, {n, m}})) 1534 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1535 return failure(); 1536 } 1537 1538 /// One outer parallel, one inner reduction (matvec flavor) 1539 FailureOr<Value> matvec() { 1540 if (!iters({Par(), Red()})) 1541 return failure(); 1542 AffineExpr m, k; 1543 bindDims(builder.getContext(), m, k); 1544 1545 // Case mat-vec: transpose. 1546 if (layout({{m, k}, {k}, {m}})) 1547 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1548 // Case mat-trans-vec: ready to go. 1549 if (layout({{k, m}, {k}, {m}})) 1550 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1551 // Case vec-mat: swap and transpose. 1552 if (layout({{k}, {m, k}, {m}})) 1553 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1554 // Case vec-mat-trans: swap and ready to go. 1555 if (layout({{k}, {k, m}, {m}})) 1556 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1557 return failure(); 1558 } 1559 1560 // 1561 // One outer reduction, one inner parallel (tmatvec flavor) 1562 // 1563 FailureOr<Value> tmatvec() { 1564 if (!iters({Red(), Par()})) 1565 return failure(); 1566 AffineExpr k, m; 1567 bindDims(builder.getContext(), k, m); 1568 1569 // Case mat-vec: transpose. 1570 if (layout({{m, k}, {k}, {m}})) 1571 return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); 1572 // Case mat-trans-vec: ready to go. 1573 if (layout({{k, m}, {k}, {m}})) 1574 return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); 1575 // Case vec-mat: swap and transpose. 1576 if (layout({{k}, {m, k}, {m}})) 1577 return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); 1578 // Case vec-mat-trans: swap and ready to go. 1579 if (layout({{k}, {k, m}, {m}})) 1580 return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); 1581 return failure(); 1582 } 1583 1584 private: 1585 vector::CombiningKind kind; 1586 Value lhs, rhs, res; 1587 VectorType lhsType; 1588 }; 1589 } // namespace 1590 1591 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul 1592 /// semantics to a reduction_size-unrolled sequence: 1593 /// ``` 1594 /// %at = vector.transpose %a, [1, 0] 1595 /// %bRow0 = vector.extract %b[0] 1596 /// %atRow0 = vector.extract %at[0] 1597 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 1598 /// ... 1599 /// %bRowK = vector.extract %b[K] 1600 /// %atRowK = vector.extract %at[K] 1601 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 1602 /// ``` 1603 /// 1604 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but 1605 /// otherwise supports any layout permutation of the matrix-multiply. 
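///
/// As a concrete, illustrative sketch (shapes and SSA names assumed here for
/// exposition only), a row-major contraction with lhs `vector<2x4xf32>`, rhs
/// `vector<4x3xf32>` and acc `vector<2x3xf32>` unrolls along the reduction
/// dimension k = 4 into a chain of outer products:
/// ```
///    %at = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
///    %a0 = vector.extract %at[0] : vector<4x2xf32>
///    %b0 = vector.extract %b[0] : vector<4x3xf32>
///    %c0 = vector.outerproduct %a0, %b0, %c : vector<2xf32>, vector<3xf32>
///    %a1 = vector.extract %at[1] : vector<4x2xf32>
///    %b1 = vector.extract %b[1] : vector<4x3xf32>
///    %c1 = vector.outerproduct %a1, %b1, %c0 : vector<2xf32>, vector<3xf32>
///    ... (two more steps; the last result replaces the original contract)
/// ```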
1606 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( 1607 vector::ContractionOp op, PatternRewriter &rewriter) const { 1608 // TODO: implement masks 1609 if (llvm::size(op.getMasks()) != 0) 1610 return failure(); 1611 1612 if (vectorTransformOptions.vectorContractLowering != 1613 vector::VectorContractLowering::OuterProduct) 1614 return failure(); 1615 1616 if (failed(filter(op))) 1617 return failure(); 1618 1619 UnrolledOuterProductGenerator e(rewriter, op); 1620 FailureOr<Value> matmatRes = e.matmat(); 1621 if (succeeded(matmatRes)) { 1622 rewriter.replaceOp(op, *matmatRes); 1623 return success(); 1624 } 1625 FailureOr<Value> matvecRes = e.matvec(); 1626 if (succeeded(matvecRes)) { 1627 rewriter.replaceOp(op, *matvecRes); 1628 return success(); 1629 } 1630 FailureOr<Value> tmatvecRes = e.tmatvec(); 1631 if (succeeded(tmatvecRes)) { 1632 rewriter.replaceOp(op, *tmatvecRes); 1633 return success(); 1634 } 1635 1636 return failure(); 1637 } 1638 1639 LogicalResult 1640 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, 1641 PatternRewriter &rewriter) const { 1642 // TODO: implement masks 1643 if (llvm::size(op.getMasks()) != 0) 1644 return failure(); 1645 1646 if (failed(filter(op))) 1647 return failure(); 1648 1649 if (vectorTransformOptions.vectorContractLowering != 1650 vector::VectorContractLowering::Dot) 1651 return failure(); 1652 1653 auto iteratorTypes = op.getIteratorTypes().getValue(); 1654 static constexpr std::array<int64_t, 2> perm = {1, 0}; 1655 Location loc = op.getLoc(); 1656 Value lhs = op.getLhs(), rhs = op.getRhs(); 1657 1658 using MapList = ArrayRef<ArrayRef<AffineExpr>>; 1659 auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; 1660 AffineExpr m, n, k; 1661 bindDims(rewriter.getContext(), m, n, k); 1662 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray(); 1663 // 1664 // In the following we wish to make the reduction dimension innermost so we 1665 // can load vectors and just fmul + reduce into a scalar. 1666 // 1667 if (isParallelIterator(iteratorTypes[0]) && 1668 isParallelIterator(iteratorTypes[1]) && 1669 isReductionIterator(iteratorTypes[2])) { 1670 // 1671 // Two outer parallel, one inner reduction (matmat flavor). 1672 // 1673 if (maps == infer({{m, k}, {k, n}, {m, n}})) { 1674 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1675 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { 1676 // No need to permute anything. 1677 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { 1678 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1679 rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1680 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { 1681 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1682 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { 1683 // This is the classical row-major matmul. Just permute the lhs. 
1684 Value tmp = lhs; 1685 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1686 rhs = tmp; 1687 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { 1688 std::swap(lhs, rhs); 1689 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { 1690 Value tmp = lhs; 1691 lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); 1692 rhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm); 1693 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { 1694 Value tmp = rhs; 1695 rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1696 lhs = tmp; 1697 } else { 1698 return failure(); 1699 } 1700 } else if (isParallelIterator(iteratorTypes[0]) && 1701 isReductionIterator(iteratorTypes[1])) { 1702 // 1703 // One outer parallel, one inner reduction (matvec flavor) 1704 // 1705 if (maps == infer({{m, n}, {n}, {m}})) { 1706 // No need to permute anything. 1707 } else if (maps == infer({{n, m}, {n}, {m}})) { 1708 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1709 } else if (maps == infer({{n}, {m, n}, {m}})) { 1710 std::swap(lhs, rhs); 1711 } else if (maps == infer({{n}, {n, m}, {m}})) { 1712 std::swap(lhs, rhs); 1713 lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm); 1714 } else { 1715 return failure(); 1716 } 1717 } else { 1718 return failure(); 1719 } 1720 1721 VectorType dstType = op.getResultType().cast<VectorType>(); 1722 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && 1723 "Expected dst type of rank 1 or 2"); 1724 1725 unsigned rank = dstType.getRank(); 1726 unsigned dstRows = dstType.getShape()[0]; 1727 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; 1728 1729 // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 1730 Value res = rewriter.create<arith::ConstantOp>(loc, dstType, 1731 rewriter.getZeroAttr(dstType)); 1732 bool isInt = dstType.getElementType().isa<IntegerType>(); 1733 for (unsigned r = 0; r < dstRows; ++r) { 1734 Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r); 1735 for (unsigned c = 0; c < dstColumns; ++c) { 1736 Value b = rank == 1 1737 ? rhs 1738 : rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c); 1739 Value m = createMul(op.getLoc(), a, b, isInt, rewriter); 1740 Value reduced = rewriter.create<vector::ReductionOp>( 1741 op.getLoc(), vector::CombiningKind::ADD, m); 1742 1743 SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r} 1744 : SmallVector<int64_t, 2>{r, c}; 1745 res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos); 1746 } 1747 } 1748 if (auto acc = op.getAcc()) 1749 res = createAdd(op.getLoc(), res, acc, isInt, rewriter); 1750 rewriter.replaceOp(op, res); 1751 return success(); 1752 } 1753 1754 /// Progressive lowering of ContractionOp. 1755 /// One: 1756 /// %x = vector.contract with at least one free/batch dimension 1757 /// is replaced by: 1758 /// %a = vector.contract with one less free/batch dimension 1759 /// %b = vector.contract with one less free/batch dimension 1760 /// .. 1761 /// %x = combine %a %b .. 1762 /// until a pure contraction is reached (no free/batch dimensions), 1763 /// which is replaced by a dot-product. 1764 /// 1765 /// This only kicks in when either VectorTransformsOptions is set 1766 /// to DOT or when other contraction patterns fail. 
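///
/// As an illustrative sketch (shapes, SSA names and the elided attributes are
/// assumed for exposition only), one peeling step of a free dimension of a
/// matmul-shaped contraction looks like:
/// ```
///    %l0 = vector.extract %lhs[0] : vector<2x4xf32>
///    %a0 = vector.extract %acc[0] : vector<2x3xf32>
///    %c0 = vector.contract {...} %l0, %rhs, %a0
///            : vector<4xf32>, vector<4x3xf32> into vector<3xf32>
///    %r0 = vector.insert %c0, %init [0] : vector<3xf32> into vector<2x3xf32>
///    ... (similarly for the second row, inserting into %r0)
/// ```
/// where the elided attributes carry the reduced indexing maps and iterator
/// types produced by adjustMap/adjustIter.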
1767 // 1768 // TODO: break down into transpose/reshape/cast ops 1769 // when they become available to avoid code dup 1770 // TODO: investigate lowering order impact on performance 1771 LogicalResult 1772 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, 1773 PatternRewriter &rewriter) const { 1774 // TODO: implement masks. 1775 if (llvm::size(op.getMasks()) != 0) 1776 return failure(); 1777 1778 if (failed(filter(op))) 1779 return failure(); 1780 1781 // TODO: support mixed mode contract lowering. 1782 if (op.getLhsType().getElementType() != 1783 getElementTypeOrSelf(op.getAccType()) || 1784 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) 1785 return failure(); 1786 1787 // TODO: implement benefits, cost models. 1788 MLIRContext *ctx = op.getContext(); 1789 ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); 1790 if (succeeded(pat1.matchAndRewrite(op, rewriter))) 1791 return success(); 1792 ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); 1793 if (succeeded(pat2.matchAndRewrite(op, rewriter))) 1794 return success(); 1795 ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); 1796 if (succeeded(pat3.matchAndRewrite(op, rewriter))) 1797 return success(); 1798 ContractOpToElementwise pat4(vectorTransformOptions, ctx); 1799 if (succeeded(pat4.matchAndRewrite(op, rewriter))) 1800 return success(); 1801 1802 // Find first batch dimension in LHS/RHS, and lower when found. 1803 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap(); 1804 if (!batchDimMap.empty()) { 1805 int64_t lhsIndex = batchDimMap[0].first; 1806 int64_t rhsIndex = batchDimMap[0].second; 1807 auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter); 1808 if (failed(newOp)) 1809 return failure(); 1810 rewriter.replaceOp(op, newOp.value()); 1811 return success(); 1812 } 1813 1814 // Collect contracting dimensions. 1815 std::vector<std::pair<int64_t, int64_t>> contractingDimMap = 1816 op.getContractingDimMap(); 1817 DenseSet<int64_t> lhsContractingDimSet; 1818 DenseSet<int64_t> rhsContractingDimSet; 1819 for (auto &dimPair : contractingDimMap) { 1820 lhsContractingDimSet.insert(dimPair.first); 1821 rhsContractingDimSet.insert(dimPair.second); 1822 } 1823 1824 // Find first free dimension in LHS, and lower when found. 1825 VectorType lhsType = op.getLhsType(); 1826 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { 1827 if (lhsContractingDimSet.count(lhsIndex) == 0) { 1828 auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter); 1829 if (failed(newOp)) 1830 return failure(); 1831 rewriter.replaceOp(op, newOp.value()); 1832 return success(); 1833 } 1834 } 1835 1836 // Find first free dimension in RHS, and lower when found. 1837 VectorType rhsType = op.getRhsType(); 1838 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { 1839 if (rhsContractingDimSet.count(rhsIndex) == 0) { 1840 auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter); 1841 if (failed(newOp)) 1842 return failure(); 1843 rewriter.replaceOp(op, newOp.value()); 1844 return success(); 1845 } 1846 } 1847 1848 // Lower the first remaining reduction dimension. 1849 if (!contractingDimMap.empty()) { 1850 auto newOp = lowerReduction(op, rewriter); 1851 if (failed(newOp)) 1852 return failure(); 1853 rewriter.replaceOp(op, newOp.value()); 1854 return success(); 1855 } 1856 1857 return failure(); 1858 } 1859 1860 // Lower one parallel dimension. 
1861 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions. 1862 // TODO: consider reusing existing contract unrolling 1863 FailureOr<Value> 1864 ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex, 1865 int64_t rhsIndex, 1866 PatternRewriter &rewriter) const { 1867 VectorType lhsType = op.getLhsType(); 1868 VectorType rhsType = op.getRhsType(); 1869 VectorType resType = op.getResultType().cast<VectorType>(); 1870 // Find the iterator type index and result index. 1871 SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray(); 1872 int64_t iterIndex = -1; 1873 int64_t dimSize = -1; 1874 if (lhsIndex >= 0) { 1875 iterIndex = iMap[0].getDimPosition(lhsIndex); 1876 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) 1877 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1878 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex 1879 << " to map to the same dimension"; 1880 }); 1881 dimSize = lhsType.getDimSize(lhsIndex); 1882 } else if (rhsIndex >= 0) { 1883 iterIndex = iMap[1].getDimPosition(rhsIndex); 1884 dimSize = rhsType.getDimSize(rhsIndex); 1885 } 1886 if (iterIndex < 0) 1887 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1888 diag << "expected either lhsIndex=" << lhsIndex 1889 << " or rhsIndex=" << rhsIndex << " to be nonnegative"; 1890 }); 1891 // value_or(-1) means that we tolerate a dimension not appearing 1892 // in the result map. That can't happen for actual parallel iterators, but 1893 // the caller ContractionOpLowering::matchAndRewrite is currently calling 1894 // lowerParallel also for the case of unit-size reduction dims appearing only 1895 // on one of LHS or RHS, not both. At the moment, such cases are created by 1896 // CastAwayContractionLeadingOneDim, so we need to either support that or 1897 // modify that pattern. 1898 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1); 1899 if (resIndex == -1 && dimSize != 1) 1900 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { 1901 diag << "expected the dimension for iterIndex=" << iterIndex 1902 << " to either appear in the result map, or to be a unit dimension"; 1903 }); 1904 // Construct new iterator types and affine map array attribute. 1905 std::array<AffineMap, 3> lowIndexingMaps = { 1906 adjustMap(iMap[0], iterIndex, rewriter), 1907 adjustMap(iMap[1], iterIndex, rewriter), 1908 adjustMap(iMap[2], iterIndex, rewriter)}; 1909 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); 1910 auto lowIter = 1911 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); 1912 // Unroll into a series of lower dimensional vector.contract ops. 1913 Location loc = op.getLoc(); 1914 Value result = rewriter.create<arith::ConstantOp>( 1915 loc, resType, rewriter.getZeroAttr(resType)); 1916 for (int64_t d = 0; d < dimSize; ++d) { 1917 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1918 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1919 auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); 1920 Value lowContract = rewriter.create<vector::ContractionOp>( 1921 loc, lhs, rhs, acc, lowAffine, lowIter); 1922 result = 1923 reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); 1924 } 1925 return result; 1926 } 1927 1928 // Lower one reduction dimension. 
FailureOr<Value>
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
                                      PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (resType.isa<VectorType>())
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = resType.isa<IntegerType>();
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap, 4> iMap = op.getIndexingMapsArray();
  Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = lookupLhs.value();
  int64_t rhsIndex = lookupRhs.value();
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "expected the RHS to also have rank 1 when the LHS has rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;
    if (auto acc = op.getAcc())
      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
          .getResult();
    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
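  // Illustrative shape of the unrolled chain for dimSize = 3 (pseudo-IR,
  // names for exposition only): each step feeds its result into the next.
  //   %r0 = vector.contract %lhs_0, %rhs_0, %acc
  //   %r1 = vector.contract %lhs_1, %rhs_1, %r0
  //   %r  = vector.contract %lhs_2, %rhs_2, %r1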
1985 Value result = op.getAcc(); 1986 for (int64_t d = 0; d < dimSize; ++d) { 1987 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); 1988 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); 1989 result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result, 1990 lowAffine, lowIter); 1991 } 1992 return result; 1993 } 1994 1995 } // namespace mlir 1996 1997 Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( 1998 OpBuilder &builder, Operation *op, ArrayRef<Value> ids, 1999 ArrayRef<int64_t> multiplicity, const AffineMap &map) { 2000 OpBuilder::InsertionGuard guard(builder); 2001 builder.setInsertionPointAfter(op); 2002 Location loc = op->getLoc(); 2003 if (op->getNumResults() != 1) 2004 return {}; 2005 Value result = op->getResult(0); 2006 VectorType type = op->getResult(0).getType().dyn_cast<VectorType>(); 2007 if (!type || map.getNumResults() != multiplicity.size()) 2008 return {}; 2009 // For each dimension being distributed check that the size is a multiple of 2010 // the multiplicity. To handle more sizes we would need to support masking. 2011 unsigned multiplictyCount = 0; 2012 for (auto exp : map.getResults()) { 2013 auto affinExp = exp.dyn_cast<AffineDimExpr>(); 2014 if (!affinExp || affinExp.getPosition() >= type.getRank() || 2015 type.getDimSize(affinExp.getPosition()) % 2016 multiplicity[multiplictyCount++] != 2017 0) 2018 return {}; 2019 } 2020 DistributeOps ops; 2021 ops.extract = 2022 builder.create<vector::ExtractMapOp>(loc, result, ids, multiplicity, map); 2023 ops.insert = 2024 builder.create<vector::InsertMapOp>(loc, ops.extract, result, ids); 2025 return ops; 2026 } 2027 2028 /// Progressive lowering of transfer_read. This pattern supports lowering of 2029 /// `vector.transfer_read` to a combination of `vector.load` and 2030 /// `vector.broadcast` if all of the following hold: 2031 /// - Stride of most minor memref dimension must be 1. 2032 /// - Out-of-bounds masking is not required. 2033 /// - If the memref's element type is a vector type then it coincides with the 2034 /// result type. 2035 /// - The permutation map doesn't perform permutation (broadcasting is allowed). 2036 struct TransferReadToVectorLoadLowering 2037 : public OpRewritePattern<vector::TransferReadOp> { 2038 TransferReadToVectorLoadLowering(MLIRContext *context, 2039 llvm::Optional<unsigned> maxRank) 2040 : OpRewritePattern<vector::TransferReadOp>(context), 2041 maxTransferRank(maxRank) {} 2042 2043 LogicalResult matchAndRewrite(vector::TransferReadOp read, 2044 PatternRewriter &rewriter) const override { 2045 if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) 2046 return failure(); 2047 2048 SmallVector<unsigned, 4> broadcastedDims; 2049 // Permutations are handled by VectorToSCF or 2050 // populateVectorTransferPermutationMapLoweringPatterns. 2051 // We let the 0-d corner case pass-through as it is supported. 2052 if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( 2053 &broadcastedDims)) 2054 return failure(); 2055 2056 auto memRefType = read.getShapedType().dyn_cast<MemRefType>(); 2057 if (!memRefType) 2058 return failure(); 2059 2060 // Non-unit strides are handled by VectorToSCF. 2061 if (!vector::isLastMemrefDimUnitStride(memRefType)) 2062 return failure(); 2063 2064 // If there is broadcasting involved then we first load the unbroadcasted 2065 // vector, and then broadcast it with `vector.broadcast`. 
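    // E.g. (illustrative types): a read of vector<4x8xf32> whose permutation
    // map broadcasts dim 0 first loads the unbroadcasted vector<1x8xf32> and
    // then broadcasts it:
    //   %l = vector.load %mem[%i, %j] : memref<16x8xf32>, vector<1x8xf32>
    //   %r = vector.broadcast %l : vector<1x8xf32> to vector<4x8xf32>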
2066 ArrayRef<int64_t> vectorShape = read.getVectorType().getShape(); 2067 SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(), 2068 vectorShape.end()); 2069 for (unsigned i : broadcastedDims) 2070 unbroadcastedVectorShape[i] = 1; 2071 VectorType unbroadcastedVectorType = VectorType::get( 2072 unbroadcastedVectorShape, read.getVectorType().getElementType()); 2073 2074 // `vector.load` supports vector types as memref's elements only when the 2075 // resulting vector type is the same as the element type. 2076 auto memrefElTy = memRefType.getElementType(); 2077 if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType) 2078 return failure(); 2079 2080 // Otherwise, element types of the memref and the vector must match. 2081 if (!memrefElTy.isa<VectorType>() && 2082 memrefElTy != read.getVectorType().getElementType()) 2083 return failure(); 2084 2085 // Out-of-bounds dims are handled by MaterializeTransferMask. 2086 if (read.hasOutOfBoundsDim()) 2087 return failure(); 2088 2089 // Create vector load op. 2090 Operation *loadOp; 2091 if (read.getMask()) { 2092 Value fill = rewriter.create<vector::SplatOp>( 2093 read.getLoc(), unbroadcastedVectorType, read.getPadding()); 2094 loadOp = rewriter.create<vector::MaskedLoadOp>( 2095 read.getLoc(), unbroadcastedVectorType, read.getSource(), 2096 read.getIndices(), read.getMask(), fill); 2097 } else { 2098 loadOp = rewriter.create<vector::LoadOp>( 2099 read.getLoc(), unbroadcastedVectorType, read.getSource(), 2100 read.getIndices()); 2101 } 2102 2103 // Insert a broadcasting op if required. 2104 if (!broadcastedDims.empty()) { 2105 rewriter.replaceOpWithNewOp<vector::BroadcastOp>( 2106 read, read.getVectorType(), loadOp->getResult(0)); 2107 } else { 2108 rewriter.replaceOp(read, loadOp->getResult(0)); 2109 } 2110 2111 return success(); 2112 } 2113 2114 llvm::Optional<unsigned> maxTransferRank; 2115 }; 2116 2117 /// Replace a 0-d vector.load with a memref.load + vector.broadcast. 2118 // TODO: we shouldn't cross the vector/scalar domains just for this 2119 // but atm we lack the infra to avoid it. Possible solutions include: 2120 // - go directly to LLVM + bitcast 2121 // - introduce a bitcast op and likely a new pointer dialect 2122 // - let memref.load/store additionally support the 0-d vector case 2123 // There are still deeper data layout issues lingering even in this 2124 // trivial case (for architectures for which this matters). 2125 struct VectorLoadToMemrefLoadLowering 2126 : public OpRewritePattern<vector::LoadOp> { 2127 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 2128 2129 LogicalResult matchAndRewrite(vector::LoadOp loadOp, 2130 PatternRewriter &rewriter) const override { 2131 auto vecType = loadOp.getVectorType(); 2132 if (vecType.getNumElements() != 1) 2133 return failure(); 2134 auto memrefLoad = rewriter.create<memref::LoadOp>( 2135 loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); 2136 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType, 2137 memrefLoad); 2138 return success(); 2139 } 2140 }; 2141 2142 /// Replace a 0-d vector.store with a vector.extractelement + memref.store. 
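///
/// The pattern also covers rank >= 1 vectors with a single element; an
/// illustrative sketch of that case (types and names assumed for exposition
/// only):
/// ```
///    vector.store %v, %m[%i] : memref<8xf32>, vector<1xf32>
/// ```
/// becomes
/// ```
///    %s = vector.extract %v[0] : vector<1xf32>
///    memref.store %s, %m[%i] : memref<8xf32>
/// ```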
2143 struct VectorStoreToMemrefStoreLowering 2144 : public OpRewritePattern<vector::StoreOp> { 2145 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 2146 2147 LogicalResult matchAndRewrite(vector::StoreOp storeOp, 2148 PatternRewriter &rewriter) const override { 2149 auto vecType = storeOp.getVectorType(); 2150 if (vecType.getNumElements() != 1) 2151 return failure(); 2152 Value extracted; 2153 if (vecType.getRank() == 0) { 2154 // TODO: Unifiy once ExtractOp supports 0-d vectors. 2155 extracted = rewriter.create<vector::ExtractElementOp>( 2156 storeOp.getLoc(), storeOp.getValueToStore()); 2157 } else { 2158 SmallVector<int64_t> indices(vecType.getRank(), 0); 2159 extracted = rewriter.create<vector::ExtractOp>( 2160 storeOp.getLoc(), storeOp.getValueToStore(), indices); 2161 } 2162 2163 rewriter.replaceOpWithNewOp<memref::StoreOp>( 2164 storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); 2165 return success(); 2166 } 2167 }; 2168 2169 /// Progressive lowering of transfer_write. This pattern supports lowering of 2170 /// `vector.transfer_write` to `vector.store` if all of the following hold: 2171 /// - Stride of most minor memref dimension must be 1. 2172 /// - Out-of-bounds masking is not required. 2173 /// - If the memref's element type is a vector type then it coincides with the 2174 /// type of the written value. 2175 /// - The permutation map is the minor identity map (neither permutation nor 2176 /// broadcasting is allowed). 2177 struct TransferWriteToVectorStoreLowering 2178 : public OpRewritePattern<vector::TransferWriteOp> { 2179 TransferWriteToVectorStoreLowering(MLIRContext *context, 2180 llvm::Optional<unsigned> maxRank) 2181 : OpRewritePattern<vector::TransferWriteOp>(context), 2182 maxTransferRank(maxRank) {} 2183 2184 LogicalResult matchAndRewrite(vector::TransferWriteOp write, 2185 PatternRewriter &rewriter) const override { 2186 if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) 2187 return failure(); 2188 2189 // Permutations are handled by VectorToSCF or 2190 // populateVectorTransferPermutationMapLoweringPatterns. 2191 if ( // pass-through for the 0-d corner case. 2192 !write.getPermutationMap().isMinorIdentity()) 2193 return failure(); 2194 2195 auto memRefType = write.getShapedType().dyn_cast<MemRefType>(); 2196 if (!memRefType) 2197 return failure(); 2198 2199 // Non-unit strides are handled by VectorToSCF. 2200 if (!vector::isLastMemrefDimUnitStride(memRefType)) 2201 return failure(); 2202 2203 // `vector.store` supports vector types as memref's elements only when the 2204 // type of the vector value being written is the same as the element type. 2205 auto memrefElTy = memRefType.getElementType(); 2206 if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType()) 2207 return failure(); 2208 2209 // Otherwise, element types of the memref and the vector must match. 2210 if (!memrefElTy.isa<VectorType>() && 2211 memrefElTy != write.getVectorType().getElementType()) 2212 return failure(); 2213 2214 // Out-of-bounds dims are handled by MaterializeTransferMask. 
    if (write.hasOutOfBoundsDim())
      return failure();
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  llvm::Optional<unsigned> maxTransferRank;
};

// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(
      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
                      [](IntegerAttr attr) { return attr.getInt(); }));
}

// Shuffles vector.bitcast op after vector.extract op.
//
// This transforms IR like:
//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
//   %1 = vector.extract %0[3] : vector<8xf16>
// Into:
//   %0 = vector.extract %src[1] : vector<4xf32>
//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
//   %2 = vector.extract %1[1] : vector<2xf16>
struct BubbleDownVectorBitCastForExtract
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Only support extracting scalars for now.
    if (extractOp.getVectorType().getRank() != 1)
      return failure();

    auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // Fail to match if we only have one element in the cast op source.
    // This is to avoid an infinite loop given that this pattern can generate
    // such cases.
    if (castSrcType.getNumElements() == 1)
      return failure();

    // Only support casting to a larger number of elements for now.
    // E.g., vector<4xf32> -> vector<8xf16>.
    if (castSrcType.getNumElements() > castDstType.getNumElements())
      return failure();

    unsigned expandRatio =
        castDstType.getNumElements() / castSrcType.getNumElements();

    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
      return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
    };

    uint64_t index = getFirstIntValue(extractOp.getPosition());

    // Get the single scalar (as a vector) in the source value that packs the
    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
    VectorType oneScalarType =
        VectorType::get({1}, castSrcType.getElementType());
    Value packedValue = rewriter.create<vector::ExtractOp>(
        extractOp.getLoc(), oneScalarType, castOp.getSource(),
        rewriter.getI64ArrayAttr(index / expandRatio));

    // Cast it to a vector with the desired scalar's type.
    // E.g. f32 -> vector<2xf16>.
    VectorType packedType =
        VectorType::get({expandRatio}, castDstType.getElementType());
    Value castedValue = rewriter.create<vector::BitCastOp>(
        extractOp.getLoc(), packedType, packedValue);

    // Finally extract the desired scalar.
2301 rewriter.replaceOpWithNewOp<vector::ExtractOp>( 2302 extractOp, extractOp.getType(), castedValue, 2303 rewriter.getI64ArrayAttr(index % expandRatio)); 2304 2305 return success(); 2306 } 2307 }; 2308 2309 // Shuffles vector.bitcast op after vector.extract_strided_slice op. 2310 // 2311 // This transforms IR like: 2312 // %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16> 2313 // %0 = vector.extract_strided_slice %cast { 2314 // offsets = [4], sizes = [4], strides = [1] 2315 // } : vector<8xf16> to vector<4xf16> 2316 // Into: 2317 // %0 = vector.extract_strided_slice %src { 2318 // offsets = [2], sizes = [2], strides = [1] 2319 // } : vector<4xf32> to vector<2xf32> 2320 // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 2321 struct BubbleDownBitCastForStridedSliceExtract 2322 : public OpRewritePattern<vector::ExtractStridedSliceOp> { 2323 using OpRewritePattern::OpRewritePattern; 2324 2325 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, 2326 PatternRewriter &rewriter) const override { 2327 auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>(); 2328 if (!castOp) 2329 return failure(); 2330 2331 VectorType castSrcType = castOp.getSourceVectorType(); 2332 VectorType castDstType = castOp.getResultVectorType(); 2333 assert(castSrcType.getRank() == castDstType.getRank()); 2334 2335 int64_t castSrcLastDim = castSrcType.getShape().back(); 2336 int64_t castDstLastDim = castDstType.getShape().back(); 2337 // Require casting to more elements for now; other cases to be implemented. 2338 if (castSrcLastDim > castDstLastDim) 2339 return failure(); 2340 2341 // Only accept all one strides for now. 2342 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(), 2343 [](const APInt &val) { return !val.isOneValue(); })) 2344 return failure(); 2345 2346 unsigned rank = extractOp.getVectorType().getRank(); 2347 assert(castDstLastDim % castSrcLastDim == 0); 2348 int64_t expandRatio = castDstLastDim / castSrcLastDim; 2349 2350 // If we have a less number of offsets than the rank, then implicitly we 2351 // are selecting the full range for the last bitcasted dimension; other 2352 // dimensions aren't affected. Otherwise, we need to scale down the last 2353 // dimension's offset given we are extracting from less elements now. 2354 ArrayAttr newOffsets = extractOp.getOffsets(); 2355 if (newOffsets.size() == rank) { 2356 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2357 if (offsets.back() % expandRatio != 0) 2358 return failure(); 2359 offsets.back() = offsets.back() / expandRatio; 2360 newOffsets = rewriter.getI64ArrayAttr(offsets); 2361 } 2362 2363 // Similarly for sizes. 
2364 ArrayAttr newSizes = extractOp.getSizes(); 2365 if (newSizes.size() == rank) { 2366 SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes); 2367 if (sizes.back() % expandRatio != 0) 2368 return failure(); 2369 sizes.back() = sizes.back() / expandRatio; 2370 newSizes = rewriter.getI64ArrayAttr(sizes); 2371 } 2372 2373 SmallVector<int64_t, 4> dims = 2374 llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape()); 2375 dims.back() = dims.back() / expandRatio; 2376 VectorType newExtractType = 2377 VectorType::get(dims, castSrcType.getElementType()); 2378 2379 auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>( 2380 extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, 2381 newSizes, extractOp.getStrides()); 2382 2383 rewriter.replaceOpWithNewOp<vector::BitCastOp>( 2384 extractOp, extractOp.getType(), newExtractOp); 2385 2386 return success(); 2387 } 2388 }; 2389 2390 // Shuffles vector.bitcast op before vector.insert_strided_slice op. 2391 // 2392 // This transforms IR like: 2393 // %0 = vector.insert_strided_slice %src, %dst { 2394 // offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16> 2395 // %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32> 2396 // Into: 2397 // %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32> 2398 // %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32> 2399 // %2 = vector.insert_strided_slice %src, %dst { 2400 // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> 2401 struct BubbleUpBitCastForStridedSliceInsert 2402 : public OpRewritePattern<vector::BitCastOp> { 2403 using OpRewritePattern::OpRewritePattern; 2404 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, 2405 PatternRewriter &rewriter) const override { 2406 VectorType castSrcType = bitcastOp.getSourceVectorType(); 2407 VectorType castDstType = bitcastOp.getResultVectorType(); 2408 assert(castSrcType.getRank() == castDstType.getRank()); 2409 2410 int64_t castSrcLastDim = castSrcType.getShape().back(); 2411 int64_t castDstLastDim = castDstType.getShape().back(); 2412 // Require casting to less elements for now; other cases to be implemented. 2413 if (castSrcLastDim < castDstLastDim) 2414 return failure(); 2415 2416 assert(castSrcLastDim % castDstLastDim == 0); 2417 int64_t shrinkRatio = castSrcLastDim / castDstLastDim; 2418 2419 auto insertOp = 2420 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>(); 2421 if (!insertOp) 2422 return failure(); 2423 2424 // Only accept all one strides for now. 2425 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(), 2426 [](const APInt &val) { return !val.isOneValue(); })) 2427 return failure(); 2428 2429 unsigned rank = insertOp.getSourceVectorType().getRank(); 2430 // Require insert op to have the same rank for the source and destination 2431 // vector; other cases to be implemented. 
2432 if (rank != insertOp.getDestVectorType().getRank()) 2433 return failure(); 2434 2435 ArrayAttr newOffsets = insertOp.getOffsets(); 2436 assert(newOffsets.size() == rank); 2437 SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets); 2438 if (offsets.back() % shrinkRatio != 0) 2439 return failure(); 2440 offsets.back() = offsets.back() / shrinkRatio; 2441 newOffsets = rewriter.getI64ArrayAttr(offsets); 2442 2443 SmallVector<int64_t, 4> srcDims = 2444 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape()); 2445 srcDims.back() = srcDims.back() / shrinkRatio; 2446 VectorType newCastSrcType = 2447 VectorType::get(srcDims, castDstType.getElementType()); 2448 2449 auto newCastSrcOp = rewriter.create<vector::BitCastOp>( 2450 bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); 2451 2452 SmallVector<int64_t, 4> dstDims = 2453 llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); 2454 dstDims.back() = dstDims.back() / shrinkRatio; 2455 VectorType newCastDstType = 2456 VectorType::get(dstDims, castDstType.getElementType()); 2457 2458 auto newCastDstOp = rewriter.create<vector::BitCastOp>( 2459 bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); 2460 2461 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>( 2462 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, 2463 insertOp.getStrides()); 2464 2465 return success(); 2466 } 2467 }; 2468 2469 // Helper that returns a vector comparison that constructs a mask: 2470 // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] 2471 // 2472 // If `dim == 0` then the result will be a 0-D vector. 2473 // 2474 // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, 2475 // much more compact, IR for this operation, but LLVM eventually 2476 // generates more elaborate instructions for this intrinsic since it 2477 // is very conservative on the boundary conditions. 2478 static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, 2479 bool force32BitVectorIndices, int64_t dim, 2480 Value b, Value *off = nullptr) { 2481 auto loc = op->getLoc(); 2482 // If we can assume all indices fit in 32-bit, we perform the vector 2483 // comparison in 32-bit to get a higher degree of SIMD parallelism. 2484 // Otherwise we perform the vector comparison using 64-bit indices. 2485 Type idxType = 2486 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type(); 2487 DenseIntElementsAttr indicesAttr; 2488 if (dim == 0 && force32BitVectorIndices) { 2489 indicesAttr = DenseIntElementsAttr::get( 2490 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0}); 2491 } else if (dim == 0) { 2492 indicesAttr = DenseIntElementsAttr::get( 2493 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0}); 2494 } else if (force32BitVectorIndices) { 2495 indicesAttr = rewriter.getI32VectorAttr( 2496 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim))); 2497 } else { 2498 indicesAttr = rewriter.getI64VectorAttr( 2499 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim))); 2500 } 2501 Value indices = rewriter.create<arith::ConstantOp>(loc, indicesAttr); 2502 // Add in an offset if requested. 2503 if (off) { 2504 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); 2505 Value ov = rewriter.create<vector::SplatOp>(loc, indices.getType(), o); 2506 indices = rewriter.create<arith::AddIOp>(loc, ov, indices); 2507 } 2508 // Construct the vector comparison. 
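  // E.g. (illustrative, 32-bit indices, dim = 4, no offset): once the bound
  // `b` has been cast to i32, the generated IR is roughly
  //   %indices = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
  //   %bounds  = vector.splat %bound : vector<4xi32>
  //   %mask    = arith.cmpi slt, %indices, %bounds : vector<4xi32>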
2509 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); 2510 Value bounds = 2511 rewriter.create<vector::SplatOp>(loc, indices.getType(), bound); 2512 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, indices, 2513 bounds); 2514 } 2515 2516 template <typename ConcreteOp> 2517 struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> { 2518 public: 2519 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt) 2520 : mlir::OpRewritePattern<ConcreteOp>(context), 2521 force32BitVectorIndices(enableIndexOpt) {} 2522 2523 LogicalResult matchAndRewrite(ConcreteOp xferOp, 2524 PatternRewriter &rewriter) const override { 2525 if (!xferOp.hasOutOfBoundsDim()) 2526 return failure(); 2527 2528 if (xferOp.getVectorType().getRank() > 1 || 2529 llvm::size(xferOp.getIndices()) == 0) 2530 return failure(); 2531 2532 Location loc = xferOp->getLoc(); 2533 VectorType vtp = xferOp.getVectorType(); 2534 2535 // Create the in-bounds mask with all elements between [0 .. dim - offset) 2536 // set and [dim - offset .. vector_length) unset. 2537 // 2538 // TODO: when the leaf transfer rank is k > 1, we need the last `k` 2539 // dimensions here. 2540 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; 2541 Value off = xferOp.getIndices()[lastIndex]; 2542 Value dim = 2543 vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); 2544 Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off); 2545 Value mask = rewriter.create<vector::CreateMaskOp>( 2546 loc, 2547 VectorType::get(vtp.getShape(), rewriter.getI1Type(), 2548 vtp.getNumScalableDims()), 2549 b); 2550 if (xferOp.getMask()) { 2551 // Intersect the in-bounds with the mask specified as an op parameter. 2552 mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask()); 2553 } 2554 2555 rewriter.updateRootInPlace(xferOp, [&]() { 2556 xferOp.getMaskMutable().assign(mask); 2557 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); 2558 }); 2559 2560 return success(); 2561 } 2562 2563 private: 2564 const bool force32BitVectorIndices; 2565 }; 2566 2567 /// Conversion pattern for a `vector.create_mask` (0-D and 1-D only). 2568 class VectorCreateMaskOpConversion 2569 : public OpRewritePattern<vector::CreateMaskOp> { 2570 public: 2571 explicit VectorCreateMaskOpConversion(MLIRContext *context, 2572 bool enableIndexOpt) 2573 : mlir::OpRewritePattern<vector::CreateMaskOp>(context), 2574 force32BitVectorIndices(enableIndexOpt) {} 2575 2576 LogicalResult matchAndRewrite(vector::CreateMaskOp op, 2577 PatternRewriter &rewriter) const override { 2578 auto dstType = op.getType(); 2579 if (dstType.cast<VectorType>().isScalable()) 2580 return failure(); 2581 int64_t rank = dstType.getRank(); 2582 if (rank > 1) 2583 return failure(); 2584 rewriter.replaceOp( 2585 op, buildVectorComparison(rewriter, op, force32BitVectorIndices, 2586 rank == 0 ? 0 : dstType.getDimSize(0), 2587 op.getOperand(0))); 2588 return success(); 2589 } 2590 2591 private: 2592 const bool force32BitVectorIndices; 2593 }; 2594 2595 // Drop inner most contiguous unit dimensions from transfer_read operand. 2596 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> { 2597 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 2598 2599 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 2600 PatternRewriter &rewriter) const override { 2601 // TODO: support 0-d corner case. 
2602 if (readOp.getTransferRank() == 0) 2603 return failure(); 2604 2605 // TODO: support mask. 2606 if (readOp.getMask()) 2607 return failure(); 2608 2609 auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>(); 2610 if (!srcType || !srcType.hasStaticShape()) 2611 return failure(); 2612 2613 if (!readOp.getPermutationMap().isMinorIdentity()) 2614 return failure(); 2615 2616 auto targetType = readOp.getVectorType(); 2617 if (targetType.getRank() <= 1) 2618 return failure(); 2619 2620 SmallVector<int64_t> srcStrides; 2621 int64_t srcOffset; 2622 if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) 2623 return failure(); 2624 2625 size_t dimsToDrop = 0; 2626 for (size_t i = 1; i < srcStrides.size(); ++i) { 2627 int dim = srcType.getRank() - i - 1; 2628 if (srcStrides[dim] == 1) { 2629 dimsToDrop++; 2630 } else { 2631 break; 2632 } 2633 } 2634 if (dimsToDrop == 0) 2635 return failure(); 2636 2637 auto resultTargetVecType = 2638 VectorType::get(targetType.getShape().drop_back(dimsToDrop), 2639 targetType.getElementType()); 2640 2641 MemRefType resultMemrefType; 2642 if (srcType.getLayout().getAffineMap().isIdentity()) { 2643 resultMemrefType = MemRefType::get( 2644 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2645 {}, srcType.getMemorySpaceAsInt()); 2646 } else { 2647 AffineMap map = srcType.getLayout().getAffineMap(); 2648 int numSymbols = map.getNumSymbols(); 2649 for (size_t i = 0; i < dimsToDrop; ++i) { 2650 int dim = srcType.getRank() - i - 1; 2651 map = map.replace(rewriter.getAffineDimExpr(dim), 2652 rewriter.getAffineConstantExpr(0), 2653 map.getNumDims() - 1, numSymbols); 2654 } 2655 resultMemrefType = MemRefType::get( 2656 srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(), 2657 map, srcType.getMemorySpaceAsInt()); 2658 } 2659 2660 auto loc = readOp.getLoc(); 2661 SmallVector<int64_t> offsets(srcType.getRank(), 0); 2662 SmallVector<int64_t> strides(srcType.getRank(), 1); 2663 2664 ArrayAttr inBoundsAttr = 2665 readOp.getInBounds() 2666 ? rewriter.getArrayAttr( 2667 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) 2668 : ArrayAttr(); 2669 Value rankedReducedView = rewriter.create<memref::SubViewOp>( 2670 loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), 2671 strides); 2672 auto permMap = getTransferMinorIdentityMap( 2673 rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType); 2674 Value result = rewriter.create<vector::TransferReadOp>( 2675 loc, resultTargetVecType, rankedReducedView, 2676 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), 2677 readOp.getPadding(), 2678 // TODO: support mask. 2679 /*mask=*/Value(), inBoundsAttr); 2680 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType, 2681 result); 2682 return success(); 2683 } 2684 }; 2685 2686 namespace { 2687 2688 /// This function checks to see if the vector combining kind 2689 /// is consistent with the integer or float element type. 
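/// For example, isValidKind(/*isInt=*/true, CombiningKind::ADD) is true,
/// whereas isValidKind(/*isInt=*/true, CombiningKind::MINF) is false, since
/// `minf` only applies to floating-point element types.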
2690 static bool isValidKind(bool isInt, vector::CombiningKind kind) { 2691 using vector::CombiningKind; 2692 enum class KindType { FLOAT, INT, INVALID }; 2693 KindType type{KindType::INVALID}; 2694 switch (kind) { 2695 case CombiningKind::MINF: 2696 case CombiningKind::MAXF: 2697 type = KindType::FLOAT; 2698 break; 2699 case CombiningKind::MINUI: 2700 case CombiningKind::MINSI: 2701 case CombiningKind::MAXUI: 2702 case CombiningKind::MAXSI: 2703 case CombiningKind::AND: 2704 case CombiningKind::OR: 2705 case CombiningKind::XOR: 2706 type = KindType::INT; 2707 break; 2708 case CombiningKind::ADD: 2709 case CombiningKind::MUL: 2710 type = isInt ? KindType::INT : KindType::FLOAT; 2711 break; 2712 } 2713 bool isValidIntKind = (type == KindType::INT) && isInt; 2714 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); 2715 return (isValidIntKind || isValidFloatKind); 2716 } 2717 2718 /// This function constructs the appropriate integer or float 2719 /// operation given the vector combining kind and operands. The 2720 /// supported int operations are : add, mul, min (signed/unsigned), 2721 /// max(signed/unsigned), and, or, xor. The supported float 2722 /// operations are : add, mul, min and max. 2723 static Value genOperator(Location loc, Value x, Value y, 2724 vector::CombiningKind kind, 2725 PatternRewriter &rewriter) { 2726 using vector::CombiningKind; 2727 2728 auto elType = x.getType().cast<VectorType>().getElementType(); 2729 bool isInt = elType.isIntOrIndex(); 2730 2731 Value combinedResult{nullptr}; 2732 switch (kind) { 2733 case CombiningKind::ADD: 2734 if (isInt) 2735 combinedResult = rewriter.create<arith::AddIOp>(loc, x, y); 2736 else 2737 combinedResult = rewriter.create<arith::AddFOp>(loc, x, y); 2738 break; 2739 case CombiningKind::MUL: 2740 if (isInt) 2741 combinedResult = rewriter.create<arith::MulIOp>(loc, x, y); 2742 else 2743 combinedResult = rewriter.create<arith::MulFOp>(loc, x, y); 2744 break; 2745 case CombiningKind::MINUI: 2746 combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y); 2747 break; 2748 case CombiningKind::MINSI: 2749 combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y); 2750 break; 2751 case CombiningKind::MAXUI: 2752 combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y); 2753 break; 2754 case CombiningKind::MAXSI: 2755 combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y); 2756 break; 2757 case CombiningKind::AND: 2758 combinedResult = rewriter.create<arith::AndIOp>(loc, x, y); 2759 break; 2760 case CombiningKind::OR: 2761 combinedResult = rewriter.create<arith::OrIOp>(loc, x, y); 2762 break; 2763 case CombiningKind::XOR: 2764 combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y); 2765 break; 2766 case CombiningKind::MINF: 2767 combinedResult = rewriter.create<arith::MinFOp>(loc, x, y); 2768 break; 2769 case CombiningKind::MAXF: 2770 combinedResult = rewriter.create<arith::MaxFOp>(loc, x, y); 2771 break; 2772 } 2773 return combinedResult; 2774 } 2775 2776 /// Convert vector.scan op into arith ops and 2777 /// vector.insert_strided_slice/extract_strided_slice 2778 /// 2779 /// Ex: 2780 /// ``` 2781 /// %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 2782 /// 1} : 2783 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) 2784 /// ``` 2785 /// Gets converted to: 2786 /// ``` 2787 /// %cst = arith.constant dense<0> : vector<2x3xi32> 2788 /// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], 2789 /// strides = [1, 1]} : vector<2x3xi32> to 
vector<2x1xi32> %1 = 2790 /// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} 2791 /// : vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice 2792 /// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : 2793 /// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : 2794 /// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], 2795 /// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = 2796 /// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], 2797 /// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, 2798 /// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = 2799 /// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = 2800 /// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : 2801 /// vector<2x3xi32>, vector<2xi32> 2802 /// ``` 2803 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { 2804 using OpRewritePattern<vector::ScanOp>::OpRewritePattern; 2805 2806 LogicalResult matchAndRewrite(vector::ScanOp scanOp, 2807 PatternRewriter &rewriter) const override { 2808 auto loc = scanOp.getLoc(); 2809 VectorType destType = scanOp.getDestType(); 2810 ArrayRef<int64_t> destShape = destType.getShape(); 2811 auto elType = destType.getElementType(); 2812 bool isInt = elType.isIntOrIndex(); 2813 if (!isValidKind(isInt, scanOp.getKind())) 2814 return failure(); 2815 2816 VectorType resType = VectorType::get(destShape, elType); 2817 Value result = rewriter.create<arith::ConstantOp>( 2818 loc, resType, rewriter.getZeroAttr(resType)); 2819 int64_t reductionDim = scanOp.getReductionDim(); 2820 bool inclusive = scanOp.getInclusive(); 2821 int64_t destRank = destType.getRank(); 2822 VectorType initialValueType = scanOp.getInitialValueType(); 2823 int64_t initialValueRank = initialValueType.getRank(); 2824 2825 SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end()); 2826 reductionShape[reductionDim] = 1; 2827 VectorType reductionType = VectorType::get(reductionShape, elType); 2828 SmallVector<int64_t> offsets(destRank, 0); 2829 SmallVector<int64_t> strides(destRank, 1); 2830 SmallVector<int64_t> sizes(destShape.begin(), destShape.end()); 2831 sizes[reductionDim] = 1; 2832 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); 2833 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); 2834 2835 Value lastOutput, lastInput; 2836 for (int i = 0; i < destShape[reductionDim]; i++) { 2837 offsets[reductionDim] = i; 2838 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); 2839 Value input = rewriter.create<vector::ExtractStridedSliceOp>( 2840 loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, 2841 scanStrides); 2842 Value output; 2843 if (i == 0) { 2844 if (inclusive) { 2845 output = input; 2846 } else { 2847 if (initialValueRank == 0) { 2848 // ShapeCastOp cannot handle 0-D vectors 2849 output = rewriter.create<vector::BroadcastOp>( 2850 loc, input.getType(), scanOp.getInitialValue()); 2851 } else { 2852 output = rewriter.create<vector::ShapeCastOp>( 2853 loc, input.getType(), scanOp.getInitialValue()); 2854 } 2855 } 2856 } else { 2857 Value y = inclusive ? 
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern<vector::ScanOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return failure();

    VectorType resType = VectorType::get(destShape, elType);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int i = 0; i < destShape[reductionDim]; i++) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          output = input;
        } else {
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors.
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        Value y = inclusive ? input : lastInput;
        output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
        assert(output != nullptr);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorMaskMaterializationPatterns(
    RewritePatternSet &patterns, bool force32BitVectorIndices) {
  patterns.add<VectorCreateMaskOpConversion,
               MaterializeTransferMask<vector::TransferReadOp>,
               MaterializeTransferMask<vector::TransferWriteOp>>(
      patterns.getContext(), force32BitVectorIndices);
}

void mlir::vector::populateShapeCastFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOpFolder>(patterns.getContext());
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleDownVectorBitCastForExtract,
               BubbleDownBitCastForStridedSliceExtract,
               BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}

void mlir::vector::populateVectorBroadcastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BroadcastOpLowering>(patterns.getContext());
}

void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
      patterns.getContext());
}

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
      patterns.getContext());
}

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<OuterProductOpLowering>(patterns.getContext());
  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
               ContractionOpToOuterProductOpLowering>(options,
                                                      patterns.getContext());
}

void mlir::vector::populateVectorTransposeLoweringPatterns(
    RewritePatternSet &patterns, VectorTransformsOptions options) {
  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
      options, patterns.getContext());
}

void mlir::vector::populateVectorReductionToContractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<MultiReduceToContract, CombineContractBroadcast,
               CombineContractTranspose, ReorderCastOpsOnBroadcast,
               ReorderElementwiseOpsOnTranspose>(patterns.getContext());
}

void mlir::vector::
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        RewritePatternSet &patterns) {
  patterns.add<DropInnerMostUnitDims>(patterns.getContext());
}

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, llvm::Optional<unsigned> maxTransferRank) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext());
}

void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ScanToArithOps>(patterns.getContext());
}
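
// Illustrative usage sketch (assumption, not part of this file's API surface):
// a client pass could run the scan lowering populated above through the greedy
// pattern driver roughly as follows. `runScanLowering` is a hypothetical
// helper and would additionally require
// mlir/Transforms/GreedyPatternRewriteDriver.h.
//
//   static LogicalResult runScanLowering(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     vector::populateVectorScanLoweringPatterns(patterns);
//     return applyPatternsAndFoldGreedily(root, std::move(patterns));
//   }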