1 //===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ 9 #define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ 10 11 #include "mlir/Transforms/DialectConversion.h" 12 13 namespace mlir { 14 class LLVMTypeConverter; 15 class ModuleOp; 16 template <typename T> 17 class OperationPass; 18 19 /// Options to control Vector to LLVM lowering. 20 /// 21 /// This should kept in sync with VectorToLLVM options defined for the 22 /// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td 23 struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptionsLowerVectorToLLVMOptions24 LowerVectorToLLVMOptions() {} 25 26 LowerVectorToLLVMOptions &enableReassociateFPReductions(bool b = true) { 27 reassociateFPReductions = b; 28 return *this; 29 } 30 LowerVectorToLLVMOptions &enableIndexOptimizations(bool b = true) { 31 force32BitVectorIndices = b; 32 return *this; 33 } 34 LowerVectorToLLVMOptions &enableArmNeon(bool b = true) { 35 armNeon = b; 36 return *this; 37 } 38 LowerVectorToLLVMOptions &enableArmSVE(bool b = true) { 39 armSVE = b; 40 return *this; 41 } 42 LowerVectorToLLVMOptions &enableAMX(bool b = true) { 43 amx = b; 44 return *this; 45 } 46 LowerVectorToLLVMOptions &enableX86Vector(bool b = true) { 47 x86Vector = b; 48 return *this; 49 } 50 51 bool reassociateFPReductions{false}; 52 bool force32BitVectorIndices{true}; 53 bool armNeon{false}; 54 bool armSVE{false}; 55 bool amx{false}; 56 bool x86Vector{false}; 57 }; 58 59 /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix 60 /// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics 61 /// will be needed when invoking LLVM. 62 void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter, 63 RewritePatternSet &patterns); 64 65 /// Collect a set of patterns to convert from the Vector dialect to LLVM. 66 void populateVectorToLLVMConversionPatterns( 67 LLVMTypeConverter &converter, RewritePatternSet &patterns, 68 bool reassociateFPReductions = false, bool force32BitVectorIndices = false); 69 70 /// Create a pass to convert vector operations to the LLVMIR dialect. 71 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass( 72 const LowerVectorToLLVMOptions &options = LowerVectorToLLVMOptions()); 73 74 } // namespace mlir 75 76 #endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_ 77