1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Vector dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 15 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/Dialect.h" 20 #include "mlir/IR/OpDefinition.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Interfaces/ControlFlowInterfaces.h" 23 #include "mlir/Interfaces/InferTypeOpInterface.h" 24 #include "mlir/Interfaces/SideEffectInterfaces.h" 25 #include "mlir/Interfaces/VectorInterfaces.h" 26 #include "mlir/Interfaces/ViewLikeInterface.h" 27 #include "llvm/ADT/StringExtras.h" 28 29 // Pull in all enum type definitions and utility function declarations. 30 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc" 31 32 namespace mlir { 33 class MLIRContext; 34 class RewritePatternSet; 35 36 namespace arith { 37 enum class AtomicRMWKind : uint64_t; 38 } // namespace arith 39 40 namespace vector { 41 class TransferReadOp; 42 class TransferWriteOp; 43 class VectorDialect; 44 45 namespace detail { 46 struct BitmaskEnumStorage; 47 } // namespace detail 48 49 /// Return whether `srcType` can be broadcast to `dstVectorType` under the 50 /// semantics of the `vector.broadcast` op. 51 enum class BroadcastableToResult { 52 Success = 0, 53 SourceRankHigher = 1, 54 DimensionMismatch = 2, 55 SourceTypeNotAVector = 3 56 }; 57 BroadcastableToResult 58 isBroadcastableTo(Type srcType, VectorType dstVectorType, 59 std::pair<int, int> *mismatchingDims = nullptr); 60 61 /// Collect a set of vector-to-vector canonicalization patterns. 62 void populateVectorToVectorCanonicalizationPatterns( 63 RewritePatternSet &patterns); 64 65 /// Collect a set of vector.shape_cast folding patterns. 66 void populateShapeCastFoldingPatterns(RewritePatternSet &patterns); 67 68 /// Collect a set of leading one dimension removal patterns. 69 /// 70 /// These patterns insert vector.shape_cast to remove leading one dimensions 71 /// to expose more canonical forms of read/write/insert/extract operations. 72 /// With them, there are more chances that we can cancel out extract-insert 73 /// pairs or forward write-read pairs. 74 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns); 75 76 /// Collect a set of one dimension removal patterns. 77 /// 78 /// These patterns insert rank-reducing memref.subview ops to remove one 79 /// dimensions. With them, there are more chances that we can avoid 80 /// potentially exensive vector.shape_cast operations. 81 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns); 82 83 /// Collect a set of patterns to flatten n-D vector transfers on contiguous 84 /// memref. 85 /// 86 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns 87 /// to transform multiple small n-D transfers into a larger 1-D transfer where 88 /// the memref contiguity properties allow it. 89 void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns); 90 91 /// Collect a set of patterns that bubble up/down bitcast ops. 92 /// 93 /// These patterns move vector.bitcast ops to be before insert ops or after 94 /// extract ops where suitable. With them, bitcast will happen on smaller 95 /// vectors and there are more chances to share extract/insert ops. 96 void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns); 97 98 /// Collect a set of transfer read/write lowering patterns. 99 /// 100 /// These patterns lower transfer ops to simpler ops like `vector.load`, 101 /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank 102 /// of a most `maxTransferRank` are lowered. This is useful when combined with 103 /// VectorToSCF, which reduces the rank of vector transfer ops. 104 void populateVectorTransferLoweringPatterns( 105 RewritePatternSet &patterns, 106 llvm::Optional<unsigned> maxTransferRank = llvm::None); 107 108 /// These patterns materialize masks for various vector ops such as transfers. 109 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, 110 bool force32BitVectorIndices); 111 112 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa 113 /// chain. 114 void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns); 115 116 /// An attribute that specifies the combining function for `vector.contract`, 117 /// and `vector.reduction`. 118 class CombiningKindAttr 119 : public Attribute::AttrBase<CombiningKindAttr, Attribute, 120 detail::BitmaskEnumStorage> { 121 public: 122 using Base::Base; 123 124 static CombiningKindAttr get(CombiningKind kind, MLIRContext *context); 125 126 CombiningKind getKind() const; 127 128 void print(AsmPrinter &p) const; 129 static Attribute parse(AsmParser &parser, Type type); 130 }; 131 132 /// Collects patterns to progressively lower vector.broadcast ops on high-D 133 /// vectors to low-D vector ops. 134 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); 135 136 /// Collects patterns to progressively lower vector mask ops into elementary 137 /// selection and insertion ops. 138 void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns); 139 140 /// Collects patterns to progressively lower vector.shape_cast ops on high-D 141 /// vectors into 1-D/2-D vector ops by generating data movement extract/insert 142 /// ops. 143 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns); 144 145 /// Returns the integer type required for subscripts in the vector dialect. 146 IntegerType getVectorSubscriptType(Builder &builder); 147 148 /// Returns an integer array attribute containing the given values using 149 /// the integer type required for subscripts in the vector dialect. 150 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values); 151 152 /// Returns the value obtained by reducing the vector into a scalar using the 153 /// operation kind associated with a binary AtomicRMWKind op. 154 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, 155 Location loc, Value vector); 156 157 /// Return true if the last dimension of the MemRefType has unit stride. Also 158 /// return true for memrefs with no strides. 159 bool isLastMemrefDimUnitStride(MemRefType type); 160 161 /// Build the default minor identity map suitable for a vector transfer. This 162 /// also handles the case memref<... x vector<...>> -> vector<...> in which the 163 /// rank of the identity map must take the vector element type into account. 164 AffineMap getTransferMinorIdentityMap(ShapedType shapedType, 165 VectorType vectorType); 166 167 /// Return true if the transfer_write fully writes the data accessed by the 168 /// transfer_read. 169 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read); 170 171 /// Return true if the write op fully over-write the priorWrite transfer_write 172 /// op. 173 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite); 174 175 /// Same behavior as `isDisjointTransferSet` but doesn't require the operations 176 /// to have the same tensor/memref. This allows comparing operations accessing 177 /// different tensors. 178 bool isDisjointTransferIndices(VectorTransferOpInterface transferA, 179 VectorTransferOpInterface transferB); 180 181 /// Return true if we can prove that the transfer operations access disjoint 182 /// memory. 183 bool isDisjointTransferSet(VectorTransferOpInterface transferA, 184 VectorTransferOpInterface transferB); 185 186 /// Return the result value of reducing two scalar/vector values with the 187 /// corresponding arith operation. 188 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, 189 Value v1, Value v2); 190 } // namespace vector 191 } // namespace mlir 192 193 #define GET_OP_CLASSES 194 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" 195 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc" 196 197 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H 198