1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H 10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H 11 12 #include <utility> 13 14 #include "mlir/Dialect/Vector/IR/VectorOps.h" 15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 namespace mlir { 20 class RewritePatternSet; 21 22 namespace vector { 23 24 //===----------------------------------------------------------------------===// 25 // Vector transformation options exposed as auxiliary structs. 26 //===----------------------------------------------------------------------===// 27 /// Enum to control the lowering of `vector.transpose` operations. 28 enum class VectorTransposeLowering { 29 /// Lower transpose into element-wise extract and inserts. 30 EltWise = 0, 31 /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix 32 /// intrinsics. 33 Flat = 1, 34 /// Lower 2-D transpose to `vector.shuffle`. 35 Shuffle = 2, 36 }; 37 /// Enum to control the lowering of `vector.multi_reduction` operations. 38 enum class VectorMultiReductionLowering { 39 /// Lower multi_reduction into outer-reduction and inner-parallel ops. 40 InnerParallel = 0, 41 /// Lower multi_reduction into outer-parallel and inner-reduction ops. 42 InnerReduction = 1, 43 }; 44 /// Enum to control the lowering of `vector.contract` operations. 45 enum class VectorContractLowering { 46 /// Progressively lower to finer grained `vector.contract` and dot-products. 47 Dot = 0, 48 /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. 49 Matmul = 1, 50 /// Lower to `vector.outerproduct`. 51 OuterProduct = 2, 52 /// Lower contract with all reduction dimensions unrolled to 1 to a vector 53 /// elementwise operations. 54 ParallelArith = 3, 55 }; 56 /// Enum to control the splitting of `vector.transfer` operations into 57 /// in-bounds and out-of-bounds variants. 58 enum class VectorTransferSplit { 59 /// Do not split vector transfer operations. 60 None = 0, 61 /// Split using in-bounds + out-of-bounds vector.transfer operations. 62 VectorTransfer = 1, 63 /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy 64 /// operations. 65 LinalgCopy = 2, 66 /// Do not split vector transfer operation but instead mark it as "in-bounds". 67 ForceInBounds = 3 68 }; 69 /// Structure to control the behavior of vector transform patterns. 70 struct VectorTransformsOptions { 71 /// Option to control the lowering of vector.contract. 72 VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; 73 VectorTransformsOptions & setVectorTransformsOptionsVectorTransformsOptions74 setVectorTransformsOptions(VectorContractLowering opt) { 75 vectorContractLowering = opt; 76 return *this; 77 } 78 /// Option to control the lowering of vector.multi_reduction. 79 VectorMultiReductionLowering vectorMultiReductionLowering = 80 VectorMultiReductionLowering::InnerParallel; 81 VectorTransformsOptions & setVectorMultiReductionLoweringVectorTransformsOptions82 setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { 83 vectorMultiReductionLowering = opt; 84 return *this; 85 } 86 /// Option to control the lowering of vector.transpose. 87 VectorTransposeLowering vectorTransposeLowering = 88 VectorTransposeLowering::EltWise; 89 VectorTransformsOptions & setVectorTransposeLoweringVectorTransformsOptions90 setVectorTransposeLowering(VectorTransposeLowering opt) { 91 vectorTransposeLowering = opt; 92 return *this; 93 } 94 /// Option to control the splitting of vector transfers. 95 VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; setVectorTransferSplitVectorTransformsOptions96 VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { 97 vectorTransferSplit = opt; 98 return *this; 99 } 100 }; 101 102 /// Options that control the vector unrolling. 103 struct UnrollVectorOptions { 104 using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>; 105 /// Callback function that indicates whether vector unrolling should be 106 /// attempted on the operation. 107 FilterConstraintFnType filterConstraint = nullptr; setFilterConstraintUnrollVectorOptions108 UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) { 109 filterConstraint = std::move(constraint); 110 return *this; 111 } 112 113 using NativeShapeFnType = 114 std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>; 115 /// Function that returns the shape of the vector to unroll to for a given 116 /// operation. The unrolling is aborted if the function returns `llvm::None`. 117 NativeShapeFnType nativeShape = nullptr; setNativeShapeFnUnrollVectorOptions118 UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) { 119 nativeShape = std::move(fn); 120 return *this; 121 } 122 123 /// Set the native shape to use for unrolling. setNativeShapeUnrollVectorOptions124 UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) { 125 SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end()); 126 nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> { 127 return tsShape; 128 }; 129 return *this; 130 } 131 132 /// Function that returns the traversal order (in terms of "for loop order", 133 /// i.e. slowest varying dimension to fastest varying dimension) that shoudl 134 /// be used when unrolling the given operation into units of the native vector 135 /// size. 136 using UnrollTraversalOrderFnType = 137 std::function<Optional<SmallVector<int64_t>>(Operation *op)>; 138 UnrollTraversalOrderFnType traversalOrderCallback = nullptr; 139 UnrollVectorOptions & setUnrollTraversalOrderFnUnrollVectorOptions140 setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) { 141 traversalOrderCallback = std::move(traversalOrderFn); 142 return *this; 143 } 144 }; 145 146 //===----------------------------------------------------------------------===// 147 // Vector transformation exposed as populate functions over rewrite patterns. 148 //===----------------------------------------------------------------------===// 149 150 /// Insert TransposeLowering patterns into extraction/insertion. 151 void populateVectorTransposeLoweringPatterns( 152 RewritePatternSet &patterns, 153 VectorTransformsOptions options = VectorTransformsOptions()); 154 155 /// Collect a set of patterns to convert vector.multi_reduction op into 156 /// a sequence of vector.reduction ops. The patterns comprise: 157 /// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such 158 /// that all reduction dimensions are either innermost or outermost, by adding 159 /// the proper vector.transpose operations. 160 /// - ReduceMultiDimReductionRank: once in innermost or outermost reduction 161 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, 162 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand 163 /// back. 164 /// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction 165 /// form, with an **outermost** reduction dimension, unroll the outer dimension 166 /// to obtain a sequence of 1-D vector ops. This also has an opportunity for 167 /// tree-reduction (in the future). 168 /// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form, 169 /// with an **innermost** reduction dimension, unroll the outer dimension to 170 /// obtain a sequence of extract + vector.reduction + insert. This can further 171 /// lower to horizontal reduction ops. 172 /// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k> 173 /// reduction (and are thus missing either a parallel or a reduction), we lift 174 /// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that 175 /// the other patterns can kick in, thus fully exiting out of the 176 /// vector.multi_reduction abstraction. 177 void populateVectorMultiReductionLoweringPatterns( 178 RewritePatternSet &patterns, VectorMultiReductionLowering options); 179 180 /// Collects patterns to progressively lower vector contraction ops on high-D 181 /// into low-D reduction and product ops. 182 void populateVectorContractLoweringPatterns( 183 RewritePatternSet &patterns, 184 VectorTransformsOptions options = VectorTransformsOptions()); 185 186 /// Collect patterns to convert reduction op to vector.contract and fold 187 /// transpose/broadcast ops into the contract. 188 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns); 189 190 /// Collect patterns to convert scan op 191 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns); 192 193 //===----------------------------------------------------------------------===// 194 // Vector.transfer patterns. 195 //===----------------------------------------------------------------------===// 196 /// Collect a set of transfer read/write lowering patterns that simplify the 197 /// permutation map (e.g., converting it to a minor identity map) by inserting 198 /// broadcasts and transposes. More specifically: 199 /// 200 /// [TransferReadPermutationLowering] 201 /// Lower transfer_read op with permutation into a transfer_read with a 202 /// permutation map composed of leading zeros followed by a minor identity + 203 /// vector.transpose op. 204 /// Ex: 205 /// vector.transfer_read ... 206 /// permutation_map: (d0, d1, d2) -> (0, d1) 207 /// into: 208 /// %v = vector.transfer_read ... 209 /// permutation_map: (d0, d1, d2) -> (d1, 0) 210 /// vector.transpose %v, [1, 0] 211 /// 212 /// vector.transfer_read ... 213 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) 214 /// into: 215 /// %v = vector.transfer_read ... 216 /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) 217 /// vector.transpose %v, [0, 1, 3, 2, 4] 218 /// Note that an alternative is to transform it to linalg.transpose + 219 /// vector.transfer_read to do the transpose in memory instead. 220 /// 221 /// [TransferWritePermutationLowering] 222 /// Lower transfer_write op with permutation into a transfer_write with a 223 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.) 224 /// Ex: 225 /// vector.transfer_write %v ... 226 /// permutation_map: (d0, d1, d2) -> (d2, d0, d1) 227 /// into: 228 /// %tmp = vector.transpose %v, [2, 0, 1] 229 /// vector.transfer_write %tmp ... 230 /// permutation_map: (d0, d1, d2) -> (d0, d1, d2) 231 /// 232 /// vector.transfer_write %v ... 233 /// permutation_map: (d0, d1, d2, d3) -> (d3, d2) 234 /// into: 235 /// %tmp = vector.transpose %v, [1, 0] 236 /// %v = vector.transfer_write %tmp ... 237 /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) 238 /// 239 /// [TransferOpReduceRank] 240 /// Lower transfer_read op with broadcast in the leading dimensions into 241 /// transfer_read of lower rank + vector.broadcast. 242 /// Ex: vector.transfer_read ... 243 /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) 244 /// into: 245 /// %v = vector.transfer_read ... 246 /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) 247 /// vector.broadcast %v 248 void populateVectorTransferPermutationMapLoweringPatterns( 249 RewritePatternSet &patterns); 250 251 /// Collect a set of patterns to reduce the rank of the operands of vector 252 /// transfer ops to operate on the largest contigious vector. 253 /// These patterns are useful when lowering to dialects with 1d vector type 254 /// such as llvm and it will result fewer memory reads. 255 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( 256 RewritePatternSet &patterns); 257 258 /// Populate `patterns` with the following patterns. 259 /// 260 /// [DecomposeDifferentRankInsertStridedSlice] 261 /// ========================================== 262 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 263 /// have different ranks. 264 /// 265 /// When ranks are different, InsertStridedSlice needs to extract a properly 266 /// ranked vector from the destination vector into which to insert. This pattern 267 /// only takes care of this extraction part and forwards the rest to 268 /// [VectorInsertStridedSliceOpSameRankRewritePattern]. 269 /// 270 /// For a k-D source and n-D destination vector (k < n), we emit: 271 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 272 /// insert the k-D source. 273 /// 2. k-D -> (n-1)-D InsertStridedSlice op 274 /// 3. InsertOp that is the reverse of 1. 275 /// 276 /// [DecomposeNDExtractStridedSlice] 277 /// ================================ 278 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower 279 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. 280 void populateVectorInsertExtractStridedSliceDecompositionPatterns( 281 RewritePatternSet &patterns); 282 283 /// Populate `patterns` with the following patterns. 284 /// 285 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns(); 286 /// 287 /// [ConvertSameRankInsertStridedSliceIntoShuffle] 288 /// ============================================== 289 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 290 /// have the same rank. For each outermost index in the slice: 291 /// begin end stride 292 /// [offset : offset+size*stride : stride] 293 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 294 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 295 /// 3. the destination subvector is inserted back in the proper place 296 /// 3. InsertOp that is the reverse of 1. 297 /// 298 /// [Convert1DExtractStridedSliceIntoShuffle] 299 /// ========================================= 300 /// For such cases, we can lower it to a ShuffleOp. 301 void populateVectorInsertExtractStridedSliceTransforms( 302 RewritePatternSet &patterns); 303 304 /// Collect a set of pattern to unroll vector operations to a smaller shapes. 305 /// `options` structure controls which operations are unrolled and the target 306 /// shape. 307 /// `op` is unrolled to the `targetShape` as follows, for each of its operands: 308 /// 1. the unrolled type `unrolledVectorType` and number of unrolled instances 309 /// `numUnrolledInstances` are computed from the `targetShape`. For now it is 310 /// assumed the unrolling factors divide the vector sizes. 311 /// 2. ExtractStridedSlice are created to break-up the vector operands. 312 /// 3. the original op is cloned `numUnrolledInstances` times, once for each 313 /// result. 314 /// 4. InsertStridedSlice are inserted to re-assemble the slices into the 315 /// original vectore shape. 316 /// 317 /// Example: 318 /// 319 /// opA(operand0, operand1) // numUnrolledInstances = 3 320 /// 321 /// operand0 operand1 322 /// | | 323 /// fork fork 324 /// <----------gather all fork ops ---------> 325 /// /|\ /|\ 326 /// f00 f01 f02 f10 f11 f12 327 /// <---------- clone op 3 times ---------> 328 /// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) 329 /// \ | / 330 /// <-------------------- join -------------------------> 331 /// 332 /// Other local patterns then kick in iteratively (including DCE) and compose 333 /// to combine the ExtractStridedSlice/InsertStridedSlice. 334 void populateVectorUnrollPatterns(RewritePatternSet &patterns, 335 const UnrollVectorOptions &options); 336 337 //===----------------------------------------------------------------------===// 338 // Finer-grained patterns exposed for more control over individual lowerings. 339 //===----------------------------------------------------------------------===// 340 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern 341 /// may take an extra filter to perform selection at a finer granularity. 342 struct VectorTransferFullPartialRewriter : public RewritePattern { 343 using FilterConstraintType = 344 std::function<LogicalResult(VectorTransferOpInterface op)>; 345 346 explicit VectorTransferFullPartialRewriter( 347 MLIRContext *context, 348 VectorTransformsOptions options = VectorTransformsOptions(), 349 FilterConstraintType filter = 350 [](VectorTransferOpInterface op) { return success(); }, 351 PatternBenefit benefit = 1) RewritePatternVectorTransferFullPartialRewriter352 : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), 353 filter(std::move(filter)) {} 354 355 /// Performs the rewrite. 356 LogicalResult matchAndRewrite(Operation *op, 357 PatternRewriter &rewriter) const override; 358 359 private: 360 VectorTransformsOptions options; 361 FilterConstraintType filter; 362 }; 363 364 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 365 /// semantics to: 366 /// ``` 367 /// %flattened_a = vector.shape_cast %a 368 /// %flattened_b = vector.shape_cast %b 369 /// %flattened_d = vector.matmul %flattened_a, %flattened_b 370 /// %d = vector.shape_cast %%flattened_d 371 /// %e = add %c, %d 372 /// ``` 373 /// `vector.matmul` later lowers to `llvm.matrix.multiply`. 374 // 375 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 376 /// the vector.contract op is a row-major matrix multiply. 377 class ContractionOpToMatmulOpLowering 378 : public OpRewritePattern<vector::ContractionOp> { 379 public: 380 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 381 using FilterConstraintType = 382 std::function<LogicalResult(vector::ContractionOp op)>; 383 defaultFilter(vector::ContractionOp op)384 static LogicalResult defaultFilter(vector::ContractionOp op) { 385 return success(); 386 } 387 388 ContractionOpToMatmulOpLowering( 389 vector::VectorTransformsOptions vectorTransformOptions, 390 MLIRContext *context, FilterConstraintType constraint = defaultFilter) 391 : OpRewritePattern<vector::ContractionOp>(context), 392 vectorTransformOptions(vectorTransformOptions), 393 filter(std::move(constraint)) {} 394 395 LogicalResult matchAndRewrite(vector::ContractionOp op, 396 PatternRewriter &rewriter) const override; 397 398 private: 399 /// Options to control the vector patterns. 400 vector::VectorTransformsOptions vectorTransformOptions; 401 FilterConstraintType filter; 402 }; 403 404 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 405 /// semantics to a reduction_size-unrolled sequence: 406 /// ``` 407 /// %at = vector.transpose %a, [1, 0] 408 /// %bRow0 = vector.extract %b[0] 409 /// %atRow0 = vector.extract %at[0] 410 /// %c0 = vector.outerproduct %atRow0, %bRow0, %c 411 /// ... 412 /// %bRowK = vector.extract %b[K] 413 /// %atRowK = vector.extract %at[K] 414 /// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 415 /// ``` 416 /// 417 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and 418 /// the vector.contract op is a row-major matrix multiply. 419 class ContractionOpToOuterProductOpLowering 420 : public OpRewritePattern<vector::ContractionOp> { 421 public: 422 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 423 using FilterConstraintType = 424 std::function<LogicalResult(vector::ContractionOp op)>; 425 defaultFilter(vector::ContractionOp op)426 static LogicalResult defaultFilter(vector::ContractionOp op) { 427 return success(); 428 } 429 430 ContractionOpToOuterProductOpLowering( 431 vector::VectorTransformsOptions vectorTransformOptions, 432 MLIRContext *context, FilterConstraintType constraint = defaultFilter) 433 : OpRewritePattern<vector::ContractionOp>(context), 434 vectorTransformOptions(vectorTransformOptions), 435 filter(std::move(constraint)) {} 436 437 LogicalResult matchAndRewrite(vector::ContractionOp op, 438 PatternRewriter &rewriter) const override; 439 440 private: 441 /// Options to control the vector patterns. 442 vector::VectorTransformsOptions vectorTransformOptions; 443 FilterConstraintType filter; 444 }; 445 446 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul 447 /// semantics to an output-size-unrolled sequence: 448 /// ``` 449 /// %out = arith.constant ... : vector<MxNxelt_type> 450 /// %bt = vector.transpose %b, [1, 0] 451 /// %aRow0 = vector.extract %a[0] 452 /// %btRow0 = vector.extract %bt[0] 453 /// %c00 = vector.reduce %atRow0, %bRow0 454 /// %out00 = vector.insert %c00, %out[0, 0] 455 /// ... 456 /// %aRowLast = vector.extract %at[M-1] 457 /// %btRowLast = vector.extract %b[N-1] 458 /// %cLastLast = vector.reduce %atRowLast, %bRowLast 459 /// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] 460 /// ``` 461 /// 462 /// This only kicks in when VectorTransformsOptions is set to Dot and 463 /// the vector.contract op is a row-major matmul or matvec. 464 class ContractionOpToDotLowering 465 : public OpRewritePattern<vector::ContractionOp> { 466 public: 467 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 468 using FilterConstraintType = 469 std::function<LogicalResult(vector::ContractionOp op)>; 470 defaultFilter(vector::ContractionOp op)471 static LogicalResult defaultFilter(vector::ContractionOp op) { 472 return success(); 473 } 474 475 ContractionOpToDotLowering( 476 vector::VectorTransformsOptions vectorTransformOptions, 477 MLIRContext *context, 478 const FilterConstraintType &constraint = defaultFilter) 479 : OpRewritePattern<vector::ContractionOp>(context), 480 vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} 481 482 LogicalResult matchAndRewrite(vector::ContractionOp op, 483 PatternRewriter &rewriter) const override; 484 485 private: 486 /// Options to control the vector patterns. 487 vector::VectorTransformsOptions vectorTransformOptions; 488 FilterConstraintType filter; 489 }; 490 491 /// Progressive lowering of ContractionOp. 492 /// 493 /// One: 494 /// %x = vector.contract with at least one free/batch dimension 495 /// is replaced by: 496 /// %a = vector.contract with one less free/batch dimension 497 /// %b = vector.contract with one less free/batch dimension 498 /// .. 499 /// %x = combine %a %b .. 500 /// until a pure contraction is reached (no free/batch dimensions), 501 /// which is replaced by a dot-product. 502 /// 503 /// This only kicks in when either VectorTransformsOptions is set 504 /// to Dot or when other contraction patterns fail. 505 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> { 506 public: 507 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; 508 using FilterConstraintType = 509 std::function<LogicalResult(vector::ContractionOp op)>; 510 defaultFilter(vector::ContractionOp op)511 static LogicalResult defaultFilter(vector::ContractionOp op) { 512 return success(); 513 } 514 515 ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, 516 MLIRContext *context, 517 FilterConstraintType constraint = defaultFilter) 518 : OpRewritePattern<vector::ContractionOp>(context), 519 vectorTransformOptions(vectorTransformOptions), 520 filter(std::move(constraint)) {} 521 522 LogicalResult matchAndRewrite(vector::ContractionOp op, 523 PatternRewriter &rewriter) const override; 524 525 private: 526 /// Options to control the vector patterns. 527 vector::VectorTransformsOptions vectorTransformOptions; 528 FilterConstraintType filter; 529 // Lower one parallel dimension. 530 FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex, 531 int64_t rhsIndex, 532 PatternRewriter &rewriter) const; 533 // Lower one reduction dimension. 534 FailureOr<Value> lowerReduction(vector::ContractionOp op, 535 PatternRewriter &rewriter) const; 536 }; 537 538 } // namespace vector 539 } // namespace mlir 540 541 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H 542