1 //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10 
11 #include "../PassDetail.h"
12 
13 #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
14 #include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
16 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
17 #include "mlir/Dialect/AVX512/AVX512Dialect.h"
18 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/Vector/VectorOps.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::vector;
27 
28 namespace {
29 struct LowerVectorToLLVMPass
30     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
31   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
32     this->reassociateFPReductions = options.reassociateFPReductions;
33     this->enableIndexOptimizations = options.enableIndexOptimizations;
34     this->enableArmNeon = options.enableArmNeon;
35     this->enableAVX512 = options.enableAVX512;
36   }
37   // Override explicitly to allow conditional dialect dependence.
38   void getDependentDialects(DialectRegistry &registry) const override {
39     registry.insert<LLVM::LLVMDialect>();
40     if (enableArmNeon)
41       registry.insert<LLVM::LLVMArmNeonDialect>();
42     if (enableAVX512)
43       registry.insert<LLVM::LLVMAVX512Dialect>();
44   }
45   void runOnOperation() override;
46 };
47 } // namespace
48 
49 void LowerVectorToLLVMPass::runOnOperation() {
50   // Perform progressive lowering of operations on slices and
51   // all contraction operations. Also applies folding and DCE.
52   {
53     OwningRewritePatternList patterns;
54     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
55     populateVectorSlicesLoweringPatterns(patterns, &getContext());
56     populateVectorContractLoweringPatterns(patterns, &getContext());
57     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
58   }
59 
60   // Convert to the LLVM IR dialect.
61   LLVMTypeConverter converter(&getContext());
62   OwningRewritePatternList patterns;
63   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
64   populateVectorToLLVMConversionPatterns(
65       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
66   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
67   populateStdToLLVMConversionPatterns(converter, patterns);
68 
69   // Architecture specific augmentations.
70   LLVMConversionTarget target(getContext());
71   if (enableArmNeon) {
72     target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
73     target.addIllegalDialect<arm_neon::ArmNeonDialect>();
74     populateArmNeonToLLVMConversionPatterns(converter, patterns);
75   }
76   if (enableAVX512) {
77     target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
78     target.addIllegalDialect<avx512::AVX512Dialect>();
79     populateAVX512ToLLVMConversionPatterns(converter, patterns);
80   }
81 
82   if (failed(
83           applyPartialConversion(getOperation(), target, std::move(patterns))))
84     signalPassFailure();
85 }
86 
87 std::unique_ptr<OperationPass<ModuleOp>>
88 mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
89   return std::make_unique<LowerVectorToLLVMPass>(options);
90 }
91