1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::x86vector;
22 
23 /// Extracts the "main" vector element type from the given X86Vector operation.
24 template <typename OpTy>
25 static Type getSrcVectorElementType(OpTy op) {
26   return op.src().getType().template cast<VectorType>().getElementType();
27 }
28 template <>
29 Type getSrcVectorElementType(Vp2IntersectOp op) {
30   return op.a().getType().template cast<VectorType>().getElementType();
31 }
32 
33 namespace {
34 
35 /// Base conversion for AVX512 ops that can be lowered to one of the two
36 /// intrinsics based on the bitwidth of their "main" vector element type. This
37 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
38 /// results of multi-result intrinsic ops.
39 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
40 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
41   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
42       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
43 
44   LLVMTypeConverter &getTypeConverter() const {
45     return *static_cast<LLVMTypeConverter *>(
46         OpConversionPattern<OpTy>::getTypeConverter());
47   }
48 
49   LogicalResult
50   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
51                   ConversionPatternRewriter &rewriter) const override {
52     Type elementType = getSrcVectorElementType<OpTy>(op);
53     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
54     if (bitwidth == 32)
55       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
56                                            adaptor.getOperands(),
57                                            getTypeConverter(), rewriter);
58     if (bitwidth == 64)
59       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
60                                            adaptor.getOperands(),
61                                            getTypeConverter(), rewriter);
62     return rewriter.notifyMatchFailure(
63         op, "expected 'src' to be either f32 or f64");
64   }
65 };
66 
67 struct MaskCompressOpConversion
68     : public ConvertOpToLLVMPattern<MaskCompressOp> {
69   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
70 
71   LogicalResult
72   matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
73                   ConversionPatternRewriter &rewriter) const override {
74     auto opType = adaptor.a().getType();
75 
76     Value src;
77     if (op.src()) {
78       src = adaptor.src();
79     } else if (op.constant_src()) {
80       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
81                                                op.constant_srcAttr());
82     } else {
83       Attribute zeroAttr = rewriter.getZeroAttr(opType);
84       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
85     }
86 
87     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
88                                                     src, adaptor.k());
89 
90     return success();
91   }
92 };
93 
94 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
95   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
96 
97   LogicalResult
98   matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override {
100     auto opType = adaptor.a().getType();
101     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
102     return success();
103   }
104 };
105 
106 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
107   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
108 
109   LogicalResult
110   matchAndRewrite(DotOp op, OpAdaptor adaptor,
111                   ConversionPatternRewriter &rewriter) const override {
112     auto opType = adaptor.a().getType();
113     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
114     // Dot product of all elements, broadcasted to all elements.
115     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
116     Value scale =
117         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
118     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(),
119                                            scale);
120     return success();
121   }
122 };
123 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops that
/// implement it for vectors of 32-bit and 64-bit elements. It carries no
/// data; the three types are exposed as the member aliases MainOp, Intr32Op
/// and Intr64Op.
template <typename MainOpT, typename Intr32OpT, typename Intr64OpT>
struct RegEntry {
  using MainOp = MainOpT;
  using Intr32Op = Intr32OpT;
  using Intr64Op = Intr64OpT;
};
132 
133 /// A container for op association entries facilitating the configuration of
134 /// dialect conversion.
135 template <typename... Args>
136 struct RegistryImpl {
137   /// Registers the patterns specializing the "main" op to one of the
138   /// "intrinsic" ops depending on elemental type.
139   static void registerPatterns(LLVMTypeConverter &converter,
140                                RewritePatternSet &patterns) {
141     patterns
142         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
143                               typename Args::Intr64Op>...>(converter);
144   }
145 
146   /// Configures the conversion target to lower out "main" ops.
147   static void configureTarget(LLVMConversionTarget &target) {
148     target.addIllegalOp<typename Args::MainOp...>();
149     target.addLegalOp<typename Args::Intr32Op...>();
150     target.addLegalOp<typename Args::Intr64Op...>();
151   }
152 };
153 
154 using Registry = RegistryImpl<
155     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
156     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
157     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
158 
159 } // namespace
160 
161 /// Populate the given list with patterns that convert from X86Vector to LLVM.
162 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
163     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
164   Registry::registerPatterns(converter, patterns);
165   patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
166       converter);
167 }
168 
169 void mlir::configureX86VectorLegalizeForExportTarget(
170     LLVMConversionTarget &target) {
171   Registry::configureTarget(target);
172   target.addLegalOp<MaskCompressIntrOp>();
173   target.addIllegalOp<MaskCompressOp>();
174   target.addLegalOp<RsqrtIntrOp>();
175   target.addIllegalOp<RsqrtOp>();
176   target.addLegalOp<DotIntrOp>();
177   target.addIllegalOp<DotOp>();
178 }
179