1 //===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 10 #define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 11 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 namespace mlir { 16 17 namespace LLVM { 18 namespace detail { 19 // Helper struct to "unroll" operations on n-D vectors in terms of operations on 20 // 1-D LLVM vectors. 21 struct NDVectorTypeInfo { 22 // LLVM array struct which encodes n-D vectors. 23 Type llvmNDVectorTy; 24 // LLVM vector type which encodes the inner 1-D vector type. 25 Type llvm1DVectorTy; 26 // Multiplicity of llvmNDVectorTy to llvm1DVectorTy. 27 SmallVector<int64_t, 4> arraySizes; 28 }; 29 30 // For >1-D vector types, extracts the necessary information to iterate over all 31 // 1-D subvectors in the underlying llrepresentation of the n-D vector 32 // Iterates on the llvm array type until we hit a non-array type (which is 33 // asserted to be an llvm vector type). 34 NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, 35 LLVMTypeConverter &converter); 36 37 // Express `linearIndex` in terms of coordinates of `basis`. 38 // Returns the empty vector when linearIndex is out of the range [0, P] where 39 // P is the product of all the basis coordinates. 40 // 41 // Prerequisites: 42 // Basis is an array of nonnegative integers (signed type inherited from 43 // vector shape type). 44 SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis, 45 unsigned linearIndex); 46 47 // Iterate of linear index, convert to coords space and insert splatted 1-D 48 // vector in each position. 49 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, 50 function_ref<void(ArrayAttr)> fun); 51 52 LogicalResult handleMultidimensionalVectors( 53 Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, 54 std::function<Value(Type, ValueRange)> createOperand, 55 ConversionPatternRewriter &rewriter); 56 57 LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, 58 ValueRange operands, 59 LLVMTypeConverter &typeConverter, 60 ConversionPatternRewriter &rewriter); 61 } // namespace detail 62 } // namespace LLVM 63 64 /// Basic lowering implementation to rewrite Ops with just one result to the 65 /// LLVM Dialect. This supports higher-dimensional vector types. 66 template <typename SourceOp, typename TargetOp> 67 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> { 68 public: 69 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 70 using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>; 71 72 LogicalResult matchAndRewrite(SourceOp op,typename SourceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter)73 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 74 ConversionPatternRewriter &rewriter) const override { 75 static_assert( 76 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 77 "expected single result op"); 78 return LLVM::detail::vectorOneToOneRewrite( 79 op, TargetOp::getOperationName(), adaptor.getOperands(), 80 *this->getTypeConverter(), rewriter); 81 } 82 }; 83 } // namespace mlir 84 85 #endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H 86