1c95acf05SAart Bik //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2c95acf05SAart Bik //
3c95acf05SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c95acf05SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5c95acf05SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6c95acf05SAart Bik //
7c95acf05SAart Bik //===----------------------------------------------------------------------===//
8c95acf05SAart Bik 
9c95acf05SAart Bik #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10c95acf05SAart Bik 
11c95acf05SAart Bik #include "../PassDetail.h"
12c95acf05SAart Bik 
13c95acf05SAart Bik #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
14*7310501fSNicolas Vasilache #include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
15c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
16c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
17c95acf05SAart Bik #include "mlir/Dialect/AVX512/AVX512Dialect.h"
18*7310501fSNicolas Vasilache #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
19c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
20*7310501fSNicolas Vasilache #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
21c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22c95acf05SAart Bik #include "mlir/Dialect/Vector/VectorOps.h"
23c95acf05SAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24c95acf05SAart Bik 
25c95acf05SAart Bik using namespace mlir;
26c95acf05SAart Bik using namespace mlir::vector;
27c95acf05SAart Bik 
28c95acf05SAart Bik namespace {
29c95acf05SAart Bik struct LowerVectorToLLVMPass
30c95acf05SAart Bik     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
31c95acf05SAart Bik   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
32c95acf05SAart Bik     this->reassociateFPReductions = options.reassociateFPReductions;
33c95acf05SAart Bik     this->enableIndexOptimizations = options.enableIndexOptimizations;
34*7310501fSNicolas Vasilache     this->enableArmNeon = options.enableArmNeon;
35c95acf05SAart Bik     this->enableAVX512 = options.enableAVX512;
36c95acf05SAart Bik   }
37*7310501fSNicolas Vasilache   // Override explicitly to allow conditional dialect dependence.
38*7310501fSNicolas Vasilache   void getDependentDialects(DialectRegistry &registry) const override {
39*7310501fSNicolas Vasilache     registry.insert<LLVM::LLVMDialect>();
40*7310501fSNicolas Vasilache     if (enableArmNeon)
41*7310501fSNicolas Vasilache       registry.insert<LLVM::LLVMArmNeonDialect>();
42*7310501fSNicolas Vasilache     if (enableAVX512)
43*7310501fSNicolas Vasilache       registry.insert<LLVM::LLVMAVX512Dialect>();
44*7310501fSNicolas Vasilache   }
45c95acf05SAart Bik   void runOnOperation() override;
46c95acf05SAart Bik };
47c95acf05SAart Bik } // namespace
48c95acf05SAart Bik 
49c95acf05SAart Bik void LowerVectorToLLVMPass::runOnOperation() {
50c95acf05SAart Bik   // Perform progressive lowering of operations on slices and
51c95acf05SAart Bik   // all contraction operations. Also applies folding and DCE.
52c95acf05SAart Bik   {
53c95acf05SAart Bik     OwningRewritePatternList patterns;
54c95acf05SAart Bik     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
55c95acf05SAart Bik     populateVectorSlicesLoweringPatterns(patterns, &getContext());
56c95acf05SAart Bik     populateVectorContractLoweringPatterns(patterns, &getContext());
57c95acf05SAart Bik     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
58c95acf05SAart Bik   }
59c95acf05SAart Bik 
60c95acf05SAart Bik   // Convert to the LLVM IR dialect.
61c95acf05SAart Bik   LLVMTypeConverter converter(&getContext());
62c95acf05SAart Bik   OwningRewritePatternList patterns;
63c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
64c95acf05SAart Bik   populateVectorToLLVMConversionPatterns(
65c95acf05SAart Bik       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
66c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
67c95acf05SAart Bik   populateStdToLLVMConversionPatterns(converter, patterns);
68c95acf05SAart Bik 
69c95acf05SAart Bik   // Architecture specific augmentations.
70c95acf05SAart Bik   LLVMConversionTarget target(getContext());
71*7310501fSNicolas Vasilache   if (enableArmNeon) {
72*7310501fSNicolas Vasilache     target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
73*7310501fSNicolas Vasilache     target.addIllegalDialect<arm_neon::ArmNeonDialect>();
74*7310501fSNicolas Vasilache     populateArmNeonToLLVMConversionPatterns(converter, patterns);
75*7310501fSNicolas Vasilache   }
76c95acf05SAart Bik   if (enableAVX512) {
77c95acf05SAart Bik     target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
78c95acf05SAart Bik     target.addIllegalDialect<avx512::AVX512Dialect>();
79c95acf05SAart Bik     populateAVX512ToLLVMConversionPatterns(converter, patterns);
80c95acf05SAart Bik   }
81c95acf05SAart Bik 
82c95acf05SAart Bik   if (failed(
83c95acf05SAart Bik           applyPartialConversion(getOperation(), target, std::move(patterns))))
84c95acf05SAart Bik     signalPassFailure();
85c95acf05SAart Bik }
86c95acf05SAart Bik 
87c95acf05SAart Bik std::unique_ptr<OperationPass<ModuleOp>>
88c95acf05SAart Bik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
89c95acf05SAart Bik   return std::make_unique<LowerVectorToLLVMPass>(options);
90c95acf05SAart Bik }
91