1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::x86vector;
20 
21 /// Extracts the "main" vector element type from the given X86Vector operation.
22 template <typename OpTy>
23 static Type getSrcVectorElementType(OpTy op) {
24   return op.src().getType().template cast<VectorType>().getElementType();
25 }
26 template <>
27 Type getSrcVectorElementType(Vp2IntersectOp op) {
28   return op.a().getType().template cast<VectorType>().getElementType();
29 }
30 
31 namespace {
32 
33 /// Base conversion for AVX512 ops that can be lowered to one of the two
34 /// intrinsics based on the bitwidth of their "main" vector element type. This
35 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
36 /// results of multi-result intrinsic ops.
37 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
38 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
39   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
40       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
41 
42   LLVMTypeConverter &getTypeConverter() const {
43     return *static_cast<LLVMTypeConverter *>(
44         OpConversionPattern<OpTy>::getTypeConverter());
45   }
46 
47   LogicalResult
48   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const override {
50     Type elementType = getSrcVectorElementType<OpTy>(op);
51     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
52     if (bitwidth == 32)
53       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
54                                            operands, getTypeConverter(),
55                                            rewriter);
56     if (bitwidth == 64)
57       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
58                                            operands, getTypeConverter(),
59                                            rewriter);
60     return rewriter.notifyMatchFailure(
61         op, "expected 'src' to be either f32 or f64");
62   }
63 };
64 
65 struct MaskCompressOpConversion
66     : public ConvertOpToLLVMPattern<MaskCompressOp> {
67   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
68 
69   LogicalResult
70   matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
71                   ConversionPatternRewriter &rewriter) const override {
72     MaskCompressOp::Adaptor adaptor(operands);
73     auto opType = adaptor.a().getType();
74 
75     Value src;
76     if (op.src()) {
77       src = adaptor.src();
78     } else if (op.constant_src()) {
79       src = rewriter.create<ConstantOp>(op.getLoc(), opType,
80                                         op.constant_srcAttr());
81     } else {
82       Attribute zeroAttr = rewriter.getZeroAttr(opType);
83       src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
84     }
85 
86     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
87                                                     src, adaptor.k());
88 
89     return success();
90   }
91 };
92 
/// Compile-time record pairing a "main" AVX512 op with the intrinsic op used
/// for vectors of 32-bit elements (`Intr32Op`) and the one used for vectors of
/// 64-bit elements (`Intr64Op`). It holds no data; only the member type
/// aliases are consumed.
template <typename Main, typename Narrow, typename Wide>
struct RegEntry {
  using MainOp = Main;
  using Intr32Op = Narrow;
  using Intr64Op = Wide;
};
101 
102 /// A container for op association entries facilitating the configuration of
103 /// dialect conversion.
104 template <typename... Args>
105 struct RegistryImpl {
106   /// Registers the patterns specializing the "main" op to one of the
107   /// "intrinsic" ops depending on elemental type.
108   static void registerPatterns(LLVMTypeConverter &converter,
109                                RewritePatternSet &patterns) {
110     patterns
111         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
112                               typename Args::Intr64Op>...>(converter);
113   }
114 
115   /// Configures the conversion target to lower out "main" ops.
116   static void configureTarget(LLVMConversionTarget &target) {
117     target.addIllegalOp<typename Args::MainOp...>();
118     target.addLegalOp<typename Args::Intr32Op...>();
119     target.addLegalOp<typename Args::Intr64Op...>();
120   }
121 };
122 
123 using Registry = RegistryImpl<
124     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
125     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
126     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
127 
128 } // namespace
129 
130 /// Populate the given list with patterns that convert from X86Vector to LLVM.
131 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
132     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
133   Registry::registerPatterns(converter, patterns);
134   patterns.add<MaskCompressOpConversion>(converter);
135 }
136 
137 void mlir::configureX86VectorLegalizeForExportTarget(
138     LLVMConversionTarget &target) {
139   Registry::configureTarget(target);
140   target.addLegalOp<MaskCompressIntrOp>();
141   target.addIllegalOp<MaskCompressOp>();
142 }
143