1*c95acf05SAart Bik //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2*c95acf05SAart Bik //
3*c95acf05SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*c95acf05SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5*c95acf05SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*c95acf05SAart Bik //
7*c95acf05SAart Bik //===----------------------------------------------------------------------===//
8*c95acf05SAart Bik 
9*c95acf05SAart Bik #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
10*c95acf05SAart Bik 
11*c95acf05SAart Bik #include "../PassDetail.h"
12*c95acf05SAart Bik 
13*c95acf05SAart Bik #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
14*c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15*c95acf05SAart Bik #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16*c95acf05SAart Bik #include "mlir/Dialect/AVX512/AVX512Dialect.h"
17*c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
18*c95acf05SAart Bik #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19*c95acf05SAart Bik #include "mlir/Dialect/Vector/VectorOps.h"
20*c95acf05SAart Bik #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21*c95acf05SAart Bik 
22*c95acf05SAart Bik using namespace mlir;
23*c95acf05SAart Bik using namespace mlir::vector;
24*c95acf05SAart Bik 
25*c95acf05SAart Bik namespace {
26*c95acf05SAart Bik struct LowerVectorToLLVMPass
27*c95acf05SAart Bik     : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
28*c95acf05SAart Bik   LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
29*c95acf05SAart Bik     this->reassociateFPReductions = options.reassociateFPReductions;
30*c95acf05SAart Bik     this->enableIndexOptimizations = options.enableIndexOptimizations;
31*c95acf05SAart Bik     this->enableAVX512 = options.enableAVX512;
32*c95acf05SAart Bik   }
33*c95acf05SAart Bik   void runOnOperation() override;
34*c95acf05SAart Bik };
35*c95acf05SAart Bik } // namespace
36*c95acf05SAart Bik 
37*c95acf05SAart Bik void LowerVectorToLLVMPass::runOnOperation() {
38*c95acf05SAart Bik   // Perform progressive lowering of operations on slices and
39*c95acf05SAart Bik   // all contraction operations. Also applies folding and DCE.
40*c95acf05SAart Bik   {
41*c95acf05SAart Bik     OwningRewritePatternList patterns;
42*c95acf05SAart Bik     populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
43*c95acf05SAart Bik     populateVectorSlicesLoweringPatterns(patterns, &getContext());
44*c95acf05SAart Bik     populateVectorContractLoweringPatterns(patterns, &getContext());
45*c95acf05SAart Bik     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
46*c95acf05SAart Bik   }
47*c95acf05SAart Bik 
48*c95acf05SAart Bik   // Convert to the LLVM IR dialect.
49*c95acf05SAart Bik   LLVMTypeConverter converter(&getContext());
50*c95acf05SAart Bik   OwningRewritePatternList patterns;
51*c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
52*c95acf05SAart Bik   populateVectorToLLVMConversionPatterns(
53*c95acf05SAart Bik       converter, patterns, reassociateFPReductions, enableIndexOptimizations);
54*c95acf05SAart Bik   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
55*c95acf05SAart Bik   populateStdToLLVMConversionPatterns(converter, patterns);
56*c95acf05SAart Bik 
57*c95acf05SAart Bik   // Architecture specific augmentations.
58*c95acf05SAart Bik   LLVMConversionTarget target(getContext());
59*c95acf05SAart Bik   if (enableAVX512) {
60*c95acf05SAart Bik     target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
61*c95acf05SAart Bik     target.addIllegalDialect<avx512::AVX512Dialect>();
62*c95acf05SAart Bik     populateAVX512ToLLVMConversionPatterns(converter, patterns);
63*c95acf05SAart Bik   }
64*c95acf05SAart Bik 
65*c95acf05SAart Bik   if (failed(
66*c95acf05SAart Bik           applyPartialConversion(getOperation(), target, std::move(patterns))))
67*c95acf05SAart Bik     signalPassFailure();
68*c95acf05SAart Bik }
69*c95acf05SAart Bik 
70*c95acf05SAart Bik std::unique_ptr<OperationPass<ModuleOp>>
71*c95acf05SAart Bik mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
72*c95acf05SAart Bik   return std::make_unique<LowerVectorToLLVMPass>(options);
73*c95acf05SAart Bik }
74