1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
15 
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Interfaces/ControlFlowInterfaces.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "mlir/Interfaces/VectorInterfaces.h"
26 #include "mlir/Interfaces/ViewLikeInterface.h"
27 #include "llvm/ADT/StringExtras.h"
28 
29 // Pull in all enum type definitions and utility function declarations.
30 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
31 
32 namespace mlir {
33 class MLIRContext;
34 class RewritePatternSet;
35 
36 namespace arith {
37 enum class AtomicRMWKind : uint64_t;
38 } // namespace arith
39 
40 namespace vector {
41 class TransferReadOp;
42 class TransferWriteOp;
43 class VectorDialect;
44 
45 namespace detail {
46 struct BitmaskEnumStorage;
47 } // namespace detail
48 
/// Outcome of `isBroadcastableTo`, which checks whether `srcType` can be
/// broadcast to `dstVectorType` under the semantics of the `vector.broadcast`
/// op. `Success` means the broadcast is legal; every other enumerator names
/// the reason it is not.
///
/// NOTE(review): the exact condition behind each failure reason lives in the
/// implementation of `isBroadcastableTo`; the per-enumerator descriptions
/// below are paraphrased from the enumerator names and should be confirmed.
enum class BroadcastableToResult : int {
  /// The source type can be broadcast to the destination vector type.
  Success = 0,
  /// Presumably: the source has a higher rank than the destination.
  SourceRankHigher = 1,
  /// Presumably: some source dimension is incompatible with the matching
  /// destination dimension.
  DimensionMismatch = 2,
  /// Presumably: the source type is not a vector type accepted as a
  /// broadcast source.
  SourceTypeNotAVector = 3,
};
57 BroadcastableToResult
58 isBroadcastableTo(Type srcType, VectorType dstVectorType,
59                   std::pair<int, int> *mismatchingDims = nullptr);
60 
61 /// Collect a set of vector-to-vector canonicalization patterns.
62 void populateVectorToVectorCanonicalizationPatterns(
63     RewritePatternSet &patterns);
64 
65 /// Collect a set of vector.shape_cast folding patterns.
66 void populateShapeCastFoldingPatterns(RewritePatternSet &patterns);
67 
68 /// Collect a set of leading one dimension removal patterns.
69 ///
70 /// These patterns insert vector.shape_cast to remove leading one dimensions
71 /// to expose more canonical forms of read/write/insert/extract operations.
72 /// With them, there are more chances that we can cancel out extract-insert
73 /// pairs or forward write-read pairs.
74 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
75 
76 /// Collect a set of one dimension removal patterns.
77 ///
78 /// These patterns insert rank-reducing memref.subview ops to remove one
79 /// dimensions. With them, there are more chances that we can avoid
80 /// potentially exensive vector.shape_cast operations.
81 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
82 
83 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
84 /// memref.
85 ///
86 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
87 /// to transform multiple small n-D transfers into a larger 1-D transfer where
88 /// the memref contiguity properties allow it.
89 void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
90 
91 /// Collect a set of patterns that bubble up/down bitcast ops.
92 ///
93 /// These patterns move vector.bitcast ops to be before insert ops or after
94 /// extract ops where suitable. With them, bitcast will happen on smaller
95 /// vectors and there are more chances to share extract/insert ops.
96 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns);
97 
98 /// Collect a set of transfer read/write lowering patterns.
99 ///
100 /// These patterns lower transfer ops to simpler ops like `vector.load`,
101 /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
102 /// of a most `maxTransferRank` are lowered. This is useful when combined with
103 /// VectorToSCF, which reduces the rank of vector transfer ops.
104 void populateVectorTransferLoweringPatterns(
105     RewritePatternSet &patterns,
106     llvm::Optional<unsigned> maxTransferRank = llvm::None);
107 
108 /// These patterns materialize masks for various vector ops such as transfers.
109 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
110                                                bool force32BitVectorIndices);
111 
112 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
113 /// chain.
114 void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
115 
116 /// An attribute that specifies the combining function for `vector.contract`,
117 /// and `vector.reduction`.
118 class CombiningKindAttr
119     : public Attribute::AttrBase<CombiningKindAttr, Attribute,
120                                  detail::BitmaskEnumStorage> {
121 public:
122   using Base::Base;
123 
124   static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
125 
126   CombiningKind getKind() const;
127 
128   void print(AsmPrinter &p) const;
129   static Attribute parse(AsmParser &parser, Type type);
130 };
131 
132 /// Collects patterns to progressively lower vector.broadcast ops on high-D
133 /// vectors to low-D vector ops.
134 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
135 
136 /// Collects patterns to progressively lower vector mask ops into elementary
137 /// selection and insertion ops.
138 void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
139 
140 /// Collects patterns to progressively lower vector.shape_cast ops on high-D
141 /// vectors into 1-D/2-D vector ops by generating data movement extract/insert
142 /// ops.
143 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
144 
145 /// Returns the integer type required for subscripts in the vector dialect.
146 IntegerType getVectorSubscriptType(Builder &builder);
147 
148 /// Returns an integer array attribute containing the given values using
149 /// the integer type required for subscripts in the vector dialect.
150 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
151 
152 /// Returns the value obtained by reducing the vector into a scalar using the
153 /// operation kind associated with a binary AtomicRMWKind op.
154 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
155                            Location loc, Value vector);
156 
157 /// Return true if the last dimension of the MemRefType has unit stride. Also
158 /// return true for memrefs with no strides.
159 bool isLastMemrefDimUnitStride(MemRefType type);
160 
161 /// Build the default minor identity map suitable for a vector transfer. This
162 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
163 /// rank of the identity map must take the vector element type into account.
164 AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
165                                       VectorType vectorType);
166 
167 /// Return true if the transfer_write fully writes the data accessed by the
168 /// transfer_read.
169 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
170 
171 /// Return true if the write op fully over-write the priorWrite transfer_write
172 /// op.
173 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
174 
175 /// Same behavior as `isDisjointTransferSet` but doesn't require the operations
176 /// to have the same tensor/memref. This allows comparing operations accessing
177 /// different tensors.
178 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
179                                VectorTransferOpInterface transferB);
180 
181 /// Return true if we can prove that the transfer operations access disjoint
182 /// memory.
183 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
184                            VectorTransferOpInterface transferB);
185 
186 /// Return the result value of reducing two scalar/vector values with the
187 /// corresponding arith operation.
188 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
189                          Value v1, Value v2);
190 } // namespace vector
191 } // namespace mlir
192 
193 #define GET_OP_CLASSES
194 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
195 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"
196 
197 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
198