1*c95acf05SAart Bik //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// 2*c95acf05SAart Bik // 3*c95acf05SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*c95acf05SAart Bik // See https://llvm.org/LICENSE.txt for license information. 5*c95acf05SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*c95acf05SAart Bik // 7*c95acf05SAart Bik //===----------------------------------------------------------------------===// 8*c95acf05SAart Bik 9*c95acf05SAart Bik #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 10*c95acf05SAart Bik 11*c95acf05SAart Bik #include "../PassDetail.h" 12*c95acf05SAart Bik 13*c95acf05SAart Bik #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" 14*c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 15*c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 16*c95acf05SAart Bik #include "mlir/Dialect/AVX512/AVX512Dialect.h" 17*c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" 18*c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19*c95acf05SAart Bik #include "mlir/Dialect/Vector/VectorOps.h" 20*c95acf05SAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21*c95acf05SAart Bik 22*c95acf05SAart Bik using namespace mlir; 23*c95acf05SAart Bik using namespace mlir::vector; 24*c95acf05SAart Bik 25*c95acf05SAart Bik namespace { 26*c95acf05SAart Bik struct LowerVectorToLLVMPass 27*c95acf05SAart Bik : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> { 28*c95acf05SAart Bik LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 29*c95acf05SAart Bik this->reassociateFPReductions = options.reassociateFPReductions; 30*c95acf05SAart Bik this->enableIndexOptimizations = options.enableIndexOptimizations; 31*c95acf05SAart Bik this->enableAVX512 = options.enableAVX512; 32*c95acf05SAart Bik } 33*c95acf05SAart Bik void runOnOperation() override; 34*c95acf05SAart Bik }; 35*c95acf05SAart Bik } // namespace 36*c95acf05SAart Bik 37*c95acf05SAart Bik void LowerVectorToLLVMPass::runOnOperation() { 38*c95acf05SAart Bik // Perform progressive lowering of operations on slices and 39*c95acf05SAart Bik // all contraction operations. Also applies folding and DCE. 40*c95acf05SAart Bik { 41*c95acf05SAart Bik OwningRewritePatternList patterns; 42*c95acf05SAart Bik populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); 43*c95acf05SAart Bik populateVectorSlicesLoweringPatterns(patterns, &getContext()); 44*c95acf05SAart Bik populateVectorContractLoweringPatterns(patterns, &getContext()); 45*c95acf05SAart Bik applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 46*c95acf05SAart Bik } 47*c95acf05SAart Bik 48*c95acf05SAart Bik // Convert to the LLVM IR dialect. 49*c95acf05SAart Bik LLVMTypeConverter converter(&getContext()); 50*c95acf05SAart Bik OwningRewritePatternList patterns; 51*c95acf05SAart Bik populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 52*c95acf05SAart Bik populateVectorToLLVMConversionPatterns( 53*c95acf05SAart Bik converter, patterns, reassociateFPReductions, enableIndexOptimizations); 54*c95acf05SAart Bik populateVectorToLLVMMatrixConversionPatterns(converter, patterns); 55*c95acf05SAart Bik populateStdToLLVMConversionPatterns(converter, patterns); 56*c95acf05SAart Bik 57*c95acf05SAart Bik // Architecture specific augmentations. 58*c95acf05SAart Bik LLVMConversionTarget target(getContext()); 59*c95acf05SAart Bik if (enableAVX512) { 60*c95acf05SAart Bik target.addLegalDialect<LLVM::LLVMAVX512Dialect>(); 61*c95acf05SAart Bik target.addIllegalDialect<avx512::AVX512Dialect>(); 62*c95acf05SAart Bik populateAVX512ToLLVMConversionPatterns(converter, patterns); 63*c95acf05SAart Bik } 64*c95acf05SAart Bik 65*c95acf05SAart Bik if (failed( 66*c95acf05SAart Bik applyPartialConversion(getOperation(), target, std::move(patterns)))) 67*c95acf05SAart Bik signalPassFailure(); 68*c95acf05SAart Bik } 69*c95acf05SAart Bik 70*c95acf05SAart Bik std::unique_ptr<OperationPass<ModuleOp>> 71*c95acf05SAart Bik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { 72*c95acf05SAart Bik return std::make_unique<LowerVectorToLLVMPass>(options); 73*c95acf05SAart Bik } 74