1 //===- VectorToSCF.h - Convert vector to SCF dialect ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
10 #define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
11 
12 #include "mlir/IR/PatternMatch.h"
13 
14 namespace mlir {
15 class MLIRContext;
16 class Pass;
17 class RewritePatternSet;
18 
/// Options controlling the lowering of vector transfer ops to SCF.
///
/// When lowering an N-d vector transfer op to an (N-1)-d vector transfer op,
/// a temporary buffer is created through which individual (N-1)-d vectors are
/// staged. This pattern can be applied multiple times, until the transfer op
/// reaches the target rank (see `targetRank`, 1-d by default).
/// This is consistent with the lack of an LLVM instruction to dynamically
/// index into an aggregate (see the Vector dialect lowering to LLVM deep dive).
///
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%a, %b, %c] :
///      vector<9x17x15xf32>, memref<?x?x?xf32>
/// ```
/// Lowers to pseudo-IR resembling (unpacking one dimension):
/// ```
///    %0 = alloca() : memref<vector<9x17x15xf32>>
///    store %vec, %0[] : memref<vector<9x17x15xf32>>
///    %1 = vector.type_cast %0 :
///      memref<vector<9x17x15xf32>> to memref<9xvector<17x15xf32>>
///    affine.for %I = 0 to 9 {
///      %dim = dim %A, 0 : memref<?x?x?xf32>
///      %add = affine.apply %I + %a
///      %cmp = arith.cmpi "slt", %add, %dim : index
///      scf.if %cmp {
///        %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>>
///        vector.transfer_write %vec_2d, %A[%add, %b, %c] :
///          vector<17x15xf32>, memref<?x?x?xf32>
///      }
///    }
/// ```
///
/// When applying the pattern a second time, the existing alloca() operation
/// is reused and only a second vector.type_cast is added.
///
/// Every setter below returns `*this`, so options can be chained, e.g.
/// `VectorTransferToSCFOptions().setTargetRank(2).enableFullUnroll()`.
struct VectorTransferToSCFOptions {
  /// Minimal rank to which vector transfers are lowered (default: 1).
  unsigned targetRank = 1;
  /// Sets `targetRank` to \p r and returns `*this` for chaining.
  VectorTransferToSCFOptions &setTargetRank(unsigned r) {
    targetRank = r;
    return *this;
  }

  /// Whether permutation maps of transfer ops are lowered as well
  /// (default: false).
  /// NOTE(review): the precise effect lives in the implementation, not in this
  /// header — presumably transfers with non-minor-identity permutation maps are
  /// rewritten before the SCF lowering; confirm against VectorToSCF.cpp.
  bool lowerPermutationMaps = false;
  /// Sets `lowerPermutationMaps` to \p l (enables it when called without an
  /// argument) and returns `*this` for chaining.
  VectorTransferToSCFOptions &enableLowerPermutationMaps(bool l = true) {
    lowerPermutationMaps = l;
    return *this;
  }

  /// Allows vector transfers that operate on tensors to be lowered (this is an
  /// uncommon alternative; default: false).
  bool lowerTensors = false;
  /// Sets `lowerTensors` to \p l (enables it when called without an argument)
  /// and returns `*this` for chaining.
  VectorTransferToSCFOptions &enableLowerTensors(bool l = true) {
    lowerTensors = l;
    return *this;
  }

  /// Triggers full unrolling (vs iterating with a loop) during transfer to scf
  /// (default: false).
  bool unroll = false;
  /// Sets `unroll` to \p u (enables it when called without an argument) and
  /// returns `*this` for chaining.
  VectorTransferToSCFOptions &enableFullUnroll(bool u = true) {
    unroll = u;
    return *this;
  }
};
76 
77 /// Collect a set of patterns to convert from the Vector dialect to SCF + func.
78 void populateVectorToSCFConversionPatterns(
79     RewritePatternSet &patterns,
80     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
81 
82 /// Create a pass to convert a subset of vector ops to SCF.
83 std::unique_ptr<Pass> createConvertVectorToSCFPass(
84     const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
85 
86 } // namespace mlir
87 
88 #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_
89