1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::x86vector;
21 
22 /// Extracts the "main" vector element type from the given X86Vector operation.
23 template <typename OpTy>
getSrcVectorElementType(OpTy op)24 static Type getSrcVectorElementType(OpTy op) {
25   return op.getSrc().getType().template cast<VectorType>().getElementType();
26 }
27 template <>
getSrcVectorElementType(Vp2IntersectOp op)28 Type getSrcVectorElementType(Vp2IntersectOp op) {
29   return op.getA().getType().template cast<VectorType>().getElementType();
30 }
31 
32 namespace {
33 
34 /// Base conversion for AVX512 ops that can be lowered to one of the two
35 /// intrinsics based on the bitwidth of their "main" vector element type. This
36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37 /// results of multi-result intrinsic ops.
38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
LowerToIntrinsic__anon338b04f50111::LowerToIntrinsic40   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
41       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42 
getTypeConverter__anon338b04f50111::LowerToIntrinsic43   LLVMTypeConverter &getTypeConverter() const {
44     return *static_cast<LLVMTypeConverter *>(
45         OpConversionPattern<OpTy>::getTypeConverter());
46   }
47 
48   LogicalResult
matchAndRewrite__anon338b04f50111::LowerToIntrinsic49   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
50                   ConversionPatternRewriter &rewriter) const override {
51     Type elementType = getSrcVectorElementType<OpTy>(op);
52     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53     if (bitwidth == 32)
54       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
55                                            adaptor.getOperands(),
56                                            getTypeConverter(), rewriter);
57     if (bitwidth == 64)
58       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
59                                            adaptor.getOperands(),
60                                            getTypeConverter(), rewriter);
61     return rewriter.notifyMatchFailure(
62         op, "expected 'src' to be either f32 or f64");
63   }
64 };
65 
66 struct MaskCompressOpConversion
67     : public ConvertOpToLLVMPattern<MaskCompressOp> {
68   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
69 
70   LogicalResult
matchAndRewrite__anon338b04f50111::MaskCompressOpConversion71   matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override {
73     auto opType = adaptor.getA().getType();
74 
75     Value src;
76     if (op.getSrc()) {
77       src = adaptor.getSrc();
78     } else if (op.getConstantSrc()) {
79       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80                                                op.getConstantSrcAttr());
81     } else {
82       Attribute zeroAttr = rewriter.getZeroAttr(opType);
83       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
84     }
85 
86     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87                                                     src, adaptor.getK());
88 
89     return success();
90   }
91 };
92 
93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
94   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
95 
96   LogicalResult
matchAndRewrite__anon338b04f50111::RsqrtOpConversion97   matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override {
99     auto opType = adaptor.getA().getType();
100     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
101     return success();
102   }
103 };
104 
105 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
106   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
107 
108   LogicalResult
matchAndRewrite__anon338b04f50111::DotOpConversion109   matchAndRewrite(DotOp op, OpAdaptor adaptor,
110                   ConversionPatternRewriter &rewriter) const override {
111     auto opType = adaptor.getA().getType();
112     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
113     // Dot product of all elements, broadcasted to all elements.
114     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
115     Value scale =
116         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
117     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
118                                            adaptor.getB(), scale);
119     return success();
120   }
121 };
122 
123 /// An entry associating the "main" AVX512 op with its instantiations for
124 /// vectors of 32-bit and 64-bit elements.
125 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
126 struct RegEntry {
127   using MainOp = OpTy;
128   using Intr32Op = Intr32OpTy;
129   using Intr64Op = Intr64OpTy;
130 };
131 
132 /// A container for op association entries facilitating the configuration of
133 /// dialect conversion.
134 template <typename... Args>
135 struct RegistryImpl {
136   /// Registers the patterns specializing the "main" op to one of the
137   /// "intrinsic" ops depending on elemental type.
registerPatterns__anon338b04f50111::RegistryImpl138   static void registerPatterns(LLVMTypeConverter &converter,
139                                RewritePatternSet &patterns) {
140     patterns
141         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
142                               typename Args::Intr64Op>...>(converter);
143   }
144 
145   /// Configures the conversion target to lower out "main" ops.
configureTarget__anon338b04f50111::RegistryImpl146   static void configureTarget(LLVMConversionTarget &target) {
147     target.addIllegalOp<typename Args::MainOp...>();
148     target.addLegalOp<typename Args::Intr32Op...>();
149     target.addLegalOp<typename Args::Intr64Op...>();
150   }
151 };
152 
153 using Registry = RegistryImpl<
154     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
155     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
156     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
157 
158 } // namespace
159 
160 /// Populate the given list with patterns that convert from X86Vector to LLVM.
populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)161 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
162     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
163   Registry::registerPatterns(converter, patterns);
164   patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165       converter);
166 }
167 
configureX86VectorLegalizeForExportTarget(LLVMConversionTarget & target)168 void mlir::configureX86VectorLegalizeForExportTarget(
169     LLVMConversionTarget &target) {
170   Registry::configureTarget(target);
171   target.addLegalOp<MaskCompressIntrOp>();
172   target.addIllegalOp<MaskCompressOp>();
173   target.addLegalOp<RsqrtIntrOp>();
174   target.addIllegalOp<RsqrtOp>();
175   target.addLegalOp<DotIntrOp>();
176   target.addIllegalOp<DotOp>();
177 }
178