1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::x86vector;
21 
22 /// Extracts the "main" vector element type from the given X86Vector operation.
23 template <typename OpTy>
24 static Type getSrcVectorElementType(OpTy op) {
25   return op.src().getType().template cast<VectorType>().getElementType();
26 }
27 template <>
28 Type getSrcVectorElementType(Vp2IntersectOp op) {
29   return op.a().getType().template cast<VectorType>().getElementType();
30 }
31 
32 namespace {
33 
34 /// Base conversion for AVX512 ops that can be lowered to one of the two
35 /// intrinsics based on the bitwidth of their "main" vector element type. This
36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37 /// results of multi-result intrinsic ops.
38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
41       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42 
43   LLVMTypeConverter &getTypeConverter() const {
44     return *static_cast<LLVMTypeConverter *>(
45         OpConversionPattern<OpTy>::getTypeConverter());
46   }
47 
48   LogicalResult
49   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
50                   ConversionPatternRewriter &rewriter) const override {
51     Type elementType = getSrcVectorElementType<OpTy>(op);
52     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53     if (bitwidth == 32)
54       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
55                                            operands, getTypeConverter(),
56                                            rewriter);
57     if (bitwidth == 64)
58       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
59                                            operands, getTypeConverter(),
60                                            rewriter);
61     return rewriter.notifyMatchFailure(
62         op, "expected 'src' to be either f32 or f64");
63   }
64 };
65 
66 struct MaskCompressOpConversion
67     : public ConvertOpToLLVMPattern<MaskCompressOp> {
68   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
69 
70   LogicalResult
71   matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
72                   ConversionPatternRewriter &rewriter) const override {
73     MaskCompressOp::Adaptor adaptor(operands);
74     auto opType = adaptor.a().getType();
75 
76     Value src;
77     if (op.src()) {
78       src = adaptor.src();
79     } else if (op.constant_src()) {
80       src = rewriter.create<ConstantOp>(op.getLoc(), opType,
81                                         op.constant_srcAttr());
82     } else {
83       Attribute zeroAttr = rewriter.getZeroAttr(opType);
84       src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
85     }
86 
87     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
88                                                     src, adaptor.k());
89 
90     return success();
91   }
92 };
93 
94 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
95   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
96 
97   LogicalResult
98   matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
99                   ConversionPatternRewriter &rewriter) const override {
100     RsqrtOp::Adaptor adaptor(operands);
101 
102     auto opType = adaptor.a().getType();
103     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
104     return success();
105   }
106 };
107 
108 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
109   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
110 
111   LogicalResult
112   matchAndRewrite(DotOp op, ArrayRef<Value> operands,
113                   ConversionPatternRewriter &rewriter) const override {
114     DotOp::Adaptor adaptor(operands);
115     auto opType = adaptor.a().getType();
116     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
117     // Dot product of all elements, broadcasted to all elements.
118     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
119     Value scale =
120         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
121     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(),
122                                            scale);
123     return success();
124   }
125 };
126 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops it
/// may lower to: one for vectors of 32-bit elements and one for vectors of
/// 64-bit elements. It only exposes the three types as member aliases.
template <typename MainOpTy, typename Intr32Ty, typename Intr64Ty>
struct RegEntry {
  using MainOp = MainOpTy;
  using Intr32Op = Intr32Ty;
  using Intr64Op = Intr64Ty;
};
135 
136 /// A container for op association entries facilitating the configuration of
137 /// dialect conversion.
138 template <typename... Args>
139 struct RegistryImpl {
140   /// Registers the patterns specializing the "main" op to one of the
141   /// "intrinsic" ops depending on elemental type.
142   static void registerPatterns(LLVMTypeConverter &converter,
143                                RewritePatternSet &patterns) {
144     patterns
145         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
146                               typename Args::Intr64Op>...>(converter);
147   }
148 
149   /// Configures the conversion target to lower out "main" ops.
150   static void configureTarget(LLVMConversionTarget &target) {
151     target.addIllegalOp<typename Args::MainOp...>();
152     target.addLegalOp<typename Args::Intr32Op...>();
153     target.addLegalOp<typename Args::Intr64Op...>();
154   }
155 };
156 
157 using Registry = RegistryImpl<
158     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
159     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
160     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
161 
162 } // namespace
163 
164 /// Populate the given list with patterns that convert from X86Vector to LLVM.
165 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
166     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
167   Registry::registerPatterns(converter, patterns);
168   patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
169       converter);
170 }
171 
172 void mlir::configureX86VectorLegalizeForExportTarget(
173     LLVMConversionTarget &target) {
174   Registry::configureTarget(target);
175   target.addLegalOp<MaskCompressIntrOp>();
176   target.addIllegalOp<MaskCompressOp>();
177   target.addLegalOp<RsqrtIntrOp>();
178   target.addIllegalOp<RsqrtOp>();
179   target.addLegalOp<DotIntrOp>();
180   target.addIllegalOp<DotOp>();
181 }
182