1 //===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
11 
12 #include <utility>
13 
14 #include "mlir/Dialect/Vector/IR/VectorOps.h"
15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 class RewritePatternSet;
21 
22 namespace vector {
23 
24 //===----------------------------------------------------------------------===//
25 // Vector transformation options exposed as auxiliary structs.
26 //===----------------------------------------------------------------------===//
/// Selects the strategy used to lower `vector.transpose` ops.
enum class VectorTransposeLowering {
  /// Emit one element-wise extract/insert pair per element.
  EltWise = 0,
  /// Turn a 2-D transpose into `vector.flat_transpose`, which maps 1-1 onto
  /// the LLVM matrix intrinsics.
  Flat = 1,
  /// Turn a 2-D transpose into a `vector.shuffle`.
  Shuffle = 2,
};
/// Selects the strategy used to lower `vector.multi_reduction` ops.
enum class VectorMultiReductionLowering {
  /// Produce ops that reduce along the outer dims and stay parallel inside.
  InnerParallel = 0,
  /// Produce ops that stay parallel along the outer dims and reduce inside.
  InnerReduction = 1,
};
/// Selects the strategy used to lower `vector.contract` ops.
enum class VectorContractLowering {
  /// Progressively break the op into finer-grained `vector.contract` ops and
  /// dot-products.
  Dot = 0,
  /// Rewrite into `vector.matrix_multiply`, which maps 1-1 onto the LLVM
  /// matrix intrinsics.
  Matmul = 1,
  /// Rewrite into `vector.outerproduct` ops.
  OuterProduct = 2,
  /// Rewrite a contraction whose reduction dimensions are all unrolled to
  /// size 1 into vector elementwise operations.
  ParallelArith = 3,
};
/// Selects whether and how `vector.transfer` ops are split into an in-bounds
/// variant and an out-of-bounds variant.
enum class VectorTransferSplit {
  /// Leave vector transfer ops unsplit.
  None = 0,
  /// Split into an in-bounds and an out-of-bounds vector.transfer op.
  VectorTransfer = 1,
  /// Split into an in-bounds vector.transfer plus linalg.fill + linalg.copy
  /// ops.
  LinalgCopy = 2,
  /// Do not split; instead tag the transfer op as "in-bounds".
  ForceInBounds = 3
};
69 /// Structure to control the behavior of vector transform patterns.
70 struct VectorTransformsOptions {
71   /// Option to control the lowering of vector.contract.
72   VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
73   VectorTransformsOptions &
setVectorTransformsOptionsVectorTransformsOptions74   setVectorTransformsOptions(VectorContractLowering opt) {
75     vectorContractLowering = opt;
76     return *this;
77   }
78   /// Option to control the lowering of vector.multi_reduction.
79   VectorMultiReductionLowering vectorMultiReductionLowering =
80       VectorMultiReductionLowering::InnerParallel;
81   VectorTransformsOptions &
setVectorMultiReductionLoweringVectorTransformsOptions82   setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
83     vectorMultiReductionLowering = opt;
84     return *this;
85   }
86   /// Option to control the lowering of vector.transpose.
87   VectorTransposeLowering vectorTransposeLowering =
88       VectorTransposeLowering::EltWise;
89   VectorTransformsOptions &
setVectorTransposeLoweringVectorTransformsOptions90   setVectorTransposeLowering(VectorTransposeLowering opt) {
91     vectorTransposeLowering = opt;
92     return *this;
93   }
94   /// Option to control the splitting of vector transfers.
95   VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
setVectorTransferSplitVectorTransformsOptions96   VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
97     vectorTransferSplit = opt;
98     return *this;
99   }
100 };
101 
102 /// Options that control the vector unrolling.
103 struct UnrollVectorOptions {
104   using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
105   /// Callback function that indicates whether vector unrolling should be
106   /// attempted on the operation.
107   FilterConstraintFnType filterConstraint = nullptr;
setFilterConstraintUnrollVectorOptions108   UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
109     filterConstraint = std::move(constraint);
110     return *this;
111   }
112 
113   using NativeShapeFnType =
114       std::function<Optional<SmallVector<int64_t, 4>>(Operation *op)>;
115   /// Function that returns the shape of the vector to unroll to for a given
116   /// operation. The unrolling is aborted if the function returns `llvm::None`.
117   NativeShapeFnType nativeShape = nullptr;
setNativeShapeFnUnrollVectorOptions118   UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) {
119     nativeShape = std::move(fn);
120     return *this;
121   }
122 
123   /// Set the native shape to use for unrolling.
setNativeShapeUnrollVectorOptions124   UnrollVectorOptions &setNativeShape(ArrayRef<int64_t> shape) {
125     SmallVector<int64_t, 4> tsShape(shape.begin(), shape.end());
126     nativeShape = [=](Operation *) -> Optional<SmallVector<int64_t, 4>> {
127       return tsShape;
128     };
129     return *this;
130   }
131 
132   /// Function that returns the traversal order (in terms of "for loop order",
133   /// i.e. slowest varying dimension to fastest varying dimension) that shoudl
134   /// be used when unrolling the given operation into units of the native vector
135   /// size.
136   using UnrollTraversalOrderFnType =
137       std::function<Optional<SmallVector<int64_t>>(Operation *op)>;
138   UnrollTraversalOrderFnType traversalOrderCallback = nullptr;
139   UnrollVectorOptions &
setUnrollTraversalOrderFnUnrollVectorOptions140   setUnrollTraversalOrderFn(UnrollTraversalOrderFnType traversalOrderFn) {
141     traversalOrderCallback = std::move(traversalOrderFn);
142     return *this;
143   }
144 };
145 
146 //===----------------------------------------------------------------------===//
147 // Vector transformation exposed as populate functions over rewrite patterns.
148 //===----------------------------------------------------------------------===//
149 
150 /// Insert TransposeLowering patterns into extraction/insertion.
151 void populateVectorTransposeLoweringPatterns(
152     RewritePatternSet &patterns,
153     VectorTransformsOptions options = VectorTransformsOptions());
154 
155 /// Collect a set of patterns to convert vector.multi_reduction op into
156 /// a sequence of vector.reduction ops. The patterns comprise:
157 /// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such
158 /// that all reduction dimensions are either innermost or outermost, by adding
159 /// the proper vector.transpose operations.
160 /// - ReduceMultiDimReductionRank: once in innermost or outermost reduction
161 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
162 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
163 /// back.
164 /// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction
165 /// form, with an **outermost** reduction dimension, unroll the outer dimension
166 /// to obtain a sequence of 1-D vector ops. This also has an opportunity for
167 /// tree-reduction (in the future).
168 /// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form,
169 /// with an **innermost** reduction dimension, unroll the outer dimension to
170 /// obtain a sequence of extract + vector.reduction + insert. This can further
171 /// lower to horizontal reduction ops.
172 /// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector<k>
173 /// reduction (and are thus missing either a parallel or a reduction), we lift
174 /// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that
175 /// the other patterns can kick in, thus fully exiting out of the
176 /// vector.multi_reduction abstraction.
177 void populateVectorMultiReductionLoweringPatterns(
178     RewritePatternSet &patterns, VectorMultiReductionLowering options);
179 
180 /// Collects patterns to progressively lower vector contraction ops on high-D
181 /// into low-D reduction and product ops.
182 void populateVectorContractLoweringPatterns(
183     RewritePatternSet &patterns,
184     VectorTransformsOptions options = VectorTransformsOptions());
185 
186 /// Collect patterns to convert reduction op to vector.contract and fold
187 /// transpose/broadcast ops into the contract.
188 void populateVectorReductionToContractPatterns(RewritePatternSet &patterns);
189 
190 /// Collect patterns to convert scan op
191 void populateVectorScanLoweringPatterns(RewritePatternSet &patterns);
192 
193 //===----------------------------------------------------------------------===//
194 // Vector.transfer patterns.
195 //===----------------------------------------------------------------------===//
196 /// Collect a set of transfer read/write lowering patterns that simplify the
197 /// permutation map (e.g., converting it to a minor identity map) by inserting
198 /// broadcasts and transposes. More specifically:
199 ///
200 /// [TransferReadPermutationLowering]
201 /// Lower transfer_read op with permutation into a transfer_read with a
202 /// permutation map composed of leading zeros followed by a minor identity +
203 /// vector.transpose op.
204 /// Ex:
205 ///     vector.transfer_read ...
206 ///         permutation_map: (d0, d1, d2) -> (0, d1)
207 /// into:
208 ///     %v = vector.transfer_read ...
209 ///         permutation_map: (d0, d1, d2) -> (d1, 0)
210 ///     vector.transpose %v, [1, 0]
211 ///
212 ///     vector.transfer_read ...
213 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
214 /// into:
215 ///     %v = vector.transfer_read ...
216 ///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
217 ///     vector.transpose %v, [0, 1, 3, 2, 4]
218 /// Note that an alternative is to transform it to linalg.transpose +
219 /// vector.transfer_read to do the transpose in memory instead.
220 ///
221 /// [TransferWritePermutationLowering]
222 /// Lower transfer_write op with permutation into a transfer_write with a
223 /// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
224 /// Ex:
225 ///     vector.transfer_write %v ...
226 ///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
227 /// into:
228 ///     %tmp = vector.transpose %v, [2, 0, 1]
229 ///     vector.transfer_write %tmp ...
230 ///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
231 ///
232 ///     vector.transfer_write %v ...
233 ///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
234 /// into:
235 ///     %tmp = vector.transpose %v, [1, 0]
236 ///     %v = vector.transfer_write %tmp ...
237 ///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
238 ///
239 /// [TransferOpReduceRank]
240 /// Lower transfer_read op with broadcast in the leading dimensions into
241 /// transfer_read of lower rank + vector.broadcast.
242 /// Ex: vector.transfer_read ...
243 ///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
244 /// into:
245 ///     %v = vector.transfer_read ...
246 ///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
247 ///     vector.broadcast %v
248 void populateVectorTransferPermutationMapLoweringPatterns(
249     RewritePatternSet &patterns);
250 
251 /// Collect a set of patterns to reduce the rank of the operands of vector
252 /// transfer ops to operate on the largest contigious vector.
253 /// These patterns are useful when lowering to dialects with 1d vector type
254 /// such as llvm and it will result fewer memory reads.
255 void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
256     RewritePatternSet &patterns);
257 
258 /// Populate `patterns` with the following patterns.
259 ///
260 /// [DecomposeDifferentRankInsertStridedSlice]
261 /// ==========================================
262 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
263 /// have different ranks.
264 ///
265 /// When ranks are different, InsertStridedSlice needs to extract a properly
266 /// ranked vector from the destination vector into which to insert. This pattern
267 /// only takes care of this extraction part and forwards the rest to
268 /// [VectorInsertStridedSliceOpSameRankRewritePattern].
269 ///
270 /// For a k-D source and n-D destination vector (k < n), we emit:
271 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
272 ///      insert the k-D source.
273 ///   2. k-D -> (n-1)-D InsertStridedSlice op
274 ///   3. InsertOp that is the reverse of 1.
275 ///
276 /// [DecomposeNDExtractStridedSlice]
277 /// ================================
278 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
279 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
280 void populateVectorInsertExtractStridedSliceDecompositionPatterns(
281     RewritePatternSet &patterns);
282 
283 /// Populate `patterns` with the following patterns.
284 ///
285 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
286 ///
287 /// [ConvertSameRankInsertStridedSliceIntoShuffle]
288 /// ==============================================
289 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
290 /// have the same rank. For each outermost index in the slice:
291 ///   begin    end             stride
292 /// [offset : offset+size*stride : stride]
293 ///   1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
294 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
295 ///   3. the destination subvector is inserted back in the proper place
296 ///   3. InsertOp that is the reverse of 1.
297 ///
298 /// [Convert1DExtractStridedSliceIntoShuffle]
299 /// =========================================
300 /// For such cases, we can lower it to a ShuffleOp.
301 void populateVectorInsertExtractStridedSliceTransforms(
302     RewritePatternSet &patterns);
303 
304 /// Collect a set of pattern to unroll vector operations to a smaller shapes.
305 /// `options` structure controls which operations are unrolled and the target
306 /// shape.
307 /// `op` is unrolled to the `targetShape` as follows, for each of its operands:
308 ///   1. the unrolled type `unrolledVectorType` and number of unrolled instances
309 ///   `numUnrolledInstances` are computed from the `targetShape`. For now it is
310 ///   assumed the unrolling factors divide the vector sizes.
311 ///   2. ExtractStridedSlice are created to break-up the vector operands.
312 ///   3. the original op is cloned `numUnrolledInstances` times, once for each
313 ///   result.
314 ///   4. InsertStridedSlice are inserted to re-assemble the slices into the
315 ///   original vectore shape.
316 ///
317 /// Example:
318 ///
319 ///    opA(operand0, operand1)  // numUnrolledInstances = 3
320 ///
321 ///            operand0                   operand1
322 ///               |                          |
323 ///             fork                       fork
324 ///        <----------gather all fork ops --------->
325 ///              /|\                        /|\
326 ///          f00 f01 f02                f10 f11 f12
327 ///        <---------- clone op 3 times --------->
328 ///          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
329 ///                 \            |            /
330 ///      <-------------------- join ------------------------->
331 ///
332 /// Other local patterns then kick in iteratively (including DCE) and compose
333 /// to combine the ExtractStridedSlice/InsertStridedSlice.
334 void populateVectorUnrollPatterns(RewritePatternSet &patterns,
335                                   const UnrollVectorOptions &options);
336 
337 //===----------------------------------------------------------------------===//
338 // Finer-grained patterns exposed for more control over individual lowerings.
339 //===----------------------------------------------------------------------===//
340 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
341 /// may take an extra filter to perform selection at a finer granularity.
342 struct VectorTransferFullPartialRewriter : public RewritePattern {
343   using FilterConstraintType =
344       std::function<LogicalResult(VectorTransferOpInterface op)>;
345 
346   explicit VectorTransferFullPartialRewriter(
347       MLIRContext *context,
348       VectorTransformsOptions options = VectorTransformsOptions(),
349       FilterConstraintType filter =
350           [](VectorTransferOpInterface op) { return success(); },
351       PatternBenefit benefit = 1)
RewritePatternVectorTransferFullPartialRewriter352       : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
353         filter(std::move(filter)) {}
354 
355   /// Performs the rewrite.
356   LogicalResult matchAndRewrite(Operation *op,
357                                 PatternRewriter &rewriter) const override;
358 
359 private:
360   VectorTransformsOptions options;
361   FilterConstraintType filter;
362 };
363 
364 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
365 /// semantics to:
366 /// ```
367 ///    %flattened_a = vector.shape_cast %a
368 ///    %flattened_b = vector.shape_cast %b
369 ///    %flattened_d = vector.matmul %flattened_a, %flattened_b
370 ///    %d = vector.shape_cast %%flattened_d
371 ///    %e = add %c, %d
372 /// ```
373 /// `vector.matmul` later lowers to `llvm.matrix.multiply`.
374 //
375 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
376 /// the vector.contract op is a row-major matrix multiply.
377 class ContractionOpToMatmulOpLowering
378     : public OpRewritePattern<vector::ContractionOp> {
379 public:
380   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
381   using FilterConstraintType =
382       std::function<LogicalResult(vector::ContractionOp op)>;
383 
defaultFilter(vector::ContractionOp op)384   static LogicalResult defaultFilter(vector::ContractionOp op) {
385     return success();
386   }
387 
388   ContractionOpToMatmulOpLowering(
389       vector::VectorTransformsOptions vectorTransformOptions,
390       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
391       : OpRewritePattern<vector::ContractionOp>(context),
392         vectorTransformOptions(vectorTransformOptions),
393         filter(std::move(constraint)) {}
394 
395   LogicalResult matchAndRewrite(vector::ContractionOp op,
396                                 PatternRewriter &rewriter) const override;
397 
398 private:
399   /// Options to control the vector patterns.
400   vector::VectorTransformsOptions vectorTransformOptions;
401   FilterConstraintType filter;
402 };
403 
404 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
405 /// semantics to a reduction_size-unrolled sequence:
406 /// ```
407 ///    %at = vector.transpose %a, [1, 0]
408 ///    %bRow0 = vector.extract %b[0]
409 ///    %atRow0 = vector.extract %at[0]
410 ///    %c0 = vector.outerproduct %atRow0, %bRow0, %c
411 ///    ...
412 ///    %bRowK = vector.extract %b[K]
413 ///    %atRowK = vector.extract %at[K]
414 ///    %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
415 /// ```
416 ///
417 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
418 /// the vector.contract op is a row-major matrix multiply.
419 class ContractionOpToOuterProductOpLowering
420     : public OpRewritePattern<vector::ContractionOp> {
421 public:
422   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
423   using FilterConstraintType =
424       std::function<LogicalResult(vector::ContractionOp op)>;
425 
defaultFilter(vector::ContractionOp op)426   static LogicalResult defaultFilter(vector::ContractionOp op) {
427     return success();
428   }
429 
430   ContractionOpToOuterProductOpLowering(
431       vector::VectorTransformsOptions vectorTransformOptions,
432       MLIRContext *context, FilterConstraintType constraint = defaultFilter)
433       : OpRewritePattern<vector::ContractionOp>(context),
434         vectorTransformOptions(vectorTransformOptions),
435         filter(std::move(constraint)) {}
436 
437   LogicalResult matchAndRewrite(vector::ContractionOp op,
438                                 PatternRewriter &rewriter) const override;
439 
440 private:
441   /// Options to control the vector patterns.
442   vector::VectorTransformsOptions vectorTransformOptions;
443   FilterConstraintType filter;
444 };
445 
446 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
447 /// semantics to an output-size-unrolled sequence:
448 /// ```
449 ///    %out = arith.constant ... : vector<MxNxelt_type>
450 ///    %bt = vector.transpose %b, [1, 0]
451 ///    %aRow0 = vector.extract %a[0]
452 ///    %btRow0 = vector.extract %bt[0]
453 ///    %c00 = vector.reduce %atRow0, %bRow0
454 ///    %out00 = vector.insert %c00, %out[0, 0]
455 ///    ...
456 ///    %aRowLast = vector.extract %at[M-1]
457 ///    %btRowLast = vector.extract %b[N-1]
458 ///    %cLastLast = vector.reduce %atRowLast, %bRowLast
459 ///    %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
460 /// ```
461 ///
462 /// This only kicks in when VectorTransformsOptions is set to Dot and
463 /// the vector.contract op is a row-major matmul or matvec.
464 class ContractionOpToDotLowering
465     : public OpRewritePattern<vector::ContractionOp> {
466 public:
467   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
468   using FilterConstraintType =
469       std::function<LogicalResult(vector::ContractionOp op)>;
470 
defaultFilter(vector::ContractionOp op)471   static LogicalResult defaultFilter(vector::ContractionOp op) {
472     return success();
473   }
474 
475   ContractionOpToDotLowering(
476       vector::VectorTransformsOptions vectorTransformOptions,
477       MLIRContext *context,
478       const FilterConstraintType &constraint = defaultFilter)
479       : OpRewritePattern<vector::ContractionOp>(context),
480         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
481 
482   LogicalResult matchAndRewrite(vector::ContractionOp op,
483                                 PatternRewriter &rewriter) const override;
484 
485 private:
486   /// Options to control the vector patterns.
487   vector::VectorTransformsOptions vectorTransformOptions;
488   FilterConstraintType filter;
489 };
490 
491 /// Progressive lowering of ContractionOp.
492 ///
493 /// One:
494 ///   %x = vector.contract with at least one free/batch dimension
495 /// is replaced by:
496 ///   %a = vector.contract with one less free/batch dimension
497 ///   %b = vector.contract with one less free/batch dimension
498 ///   ..
499 ///   %x = combine %a %b ..
500 /// until a pure contraction is reached (no free/batch dimensions),
501 /// which is replaced by a dot-product.
502 ///
503 /// This only kicks in when either VectorTransformsOptions is set
504 /// to Dot or when other contraction patterns fail.
505 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
506 public:
507   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
508   using FilterConstraintType =
509       std::function<LogicalResult(vector::ContractionOp op)>;
510 
defaultFilter(vector::ContractionOp op)511   static LogicalResult defaultFilter(vector::ContractionOp op) {
512     return success();
513   }
514 
515   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
516                         MLIRContext *context,
517                         FilterConstraintType constraint = defaultFilter)
518       : OpRewritePattern<vector::ContractionOp>(context),
519         vectorTransformOptions(vectorTransformOptions),
520         filter(std::move(constraint)) {}
521 
522   LogicalResult matchAndRewrite(vector::ContractionOp op,
523                                 PatternRewriter &rewriter) const override;
524 
525 private:
526   /// Options to control the vector patterns.
527   vector::VectorTransformsOptions vectorTransformOptions;
528   FilterConstraintType filter;
529   // Lower one parallel dimension.
530   FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
531                                  int64_t rhsIndex,
532                                  PatternRewriter &rewriter) const;
533   // Lower one reduction dimension.
534   FailureOr<Value> lowerReduction(vector::ContractionOp op,
535                                   PatternRewriter &rewriter) const;
536 };
537 
538 } // namespace vector
539 } // namespace mlir
540 
541 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H
542