1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/X86Vector/Transforms.h"
10 
11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::x86vector;
20 
21 /// Extracts the "main" vector element type from the given X86Vector operation.
22 template <typename OpTy>
23 static Type getSrcVectorElementType(OpTy op) {
24   return op.src().getType().template cast<VectorType>().getElementType();
25 }
26 template <>
27 Type getSrcVectorElementType(Vp2IntersectOp op) {
28   return op.a().getType().template cast<VectorType>().getElementType();
29 }
30 
31 namespace {
32 
33 /// Base conversion for AVX512 ops that can be lowered to one of the two
34 /// intrinsics based on the bitwidth of their "main" vector element type. This
35 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
36 /// results of multi-result intrinsic ops.
37 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
38 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
39   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
40       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
41 
42   LLVMTypeConverter &getTypeConverter() const {
43     return *static_cast<LLVMTypeConverter *>(
44         OpConversionPattern<OpTy>::getTypeConverter());
45   }
46 
47   LogicalResult
48   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const override {
50     Type elementType = getSrcVectorElementType<OpTy>(op);
51     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
52     if (bitwidth == 32)
53       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
54                                            operands, getTypeConverter(),
55                                            rewriter);
56     if (bitwidth == 64)
57       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
58                                            operands, getTypeConverter(),
59                                            rewriter);
60     return rewriter.notifyMatchFailure(
61         op, "expected 'src' to be either f32 or f64");
62   }
63 };
64 
65 struct MaskCompressOpConversion
66     : public ConvertOpToLLVMPattern<MaskCompressOp> {
67   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
68 
69   LogicalResult
70   matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
71                   ConversionPatternRewriter &rewriter) const override {
72     MaskCompressOp::Adaptor adaptor(operands);
73     auto opType = adaptor.a().getType();
74 
75     Value src;
76     if (op.src()) {
77       src = adaptor.src();
78     } else if (op.constant_src()) {
79       src = rewriter.create<ConstantOp>(op.getLoc(), opType,
80                                         op.constant_srcAttr());
81     } else {
82       Attribute zeroAttr = rewriter.getZeroAttr(opType);
83       src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
84     }
85 
86     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
87                                                     src, adaptor.k());
88 
89     return success();
90   }
91 };
92 
93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
94   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
95 
96   LogicalResult
97   matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
98                   ConversionPatternRewriter &rewriter) const override {
99     RsqrtOp::Adaptor adaptor(operands);
100 
101     auto opType = adaptor.a().getType();
102     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
103     return success();
104   }
105 };
106 
107 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
108   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
109 
110   LogicalResult
111   matchAndRewrite(DotOp op, ArrayRef<Value> operands,
112                   ConversionPatternRewriter &rewriter) const override {
113     DotOp::Adaptor adaptor(operands);
114     auto opType = adaptor.a().getType();
115     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
116     // Dot product of all elements, broadcasted to all elements.
117     auto attr = rewriter.getI8IntegerAttr(0xff);
118     Value scale =
119         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
120     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(),
121                                            scale);
122     return success();
123   }
124 };
125 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops it
/// can be specialized to: one for vectors of 32-bit elements and one for
/// vectors of 64-bit elements. The three types are exposed as member aliases.
template <typename MainOpTy, typename NarrowIntrOpTy, typename WideIntrOpTy>
struct RegEntry {
  /// The op that gets lowered.
  using MainOp = MainOpTy;
  /// The intrinsic used for 32-bit element vectors.
  using Intr32Op = NarrowIntrOpTy;
  /// The intrinsic used for 64-bit element vectors.
  using Intr64Op = WideIntrOpTy;
};
134 
135 /// A container for op association entries facilitating the configuration of
136 /// dialect conversion.
137 template <typename... Args>
138 struct RegistryImpl {
139   /// Registers the patterns specializing the "main" op to one of the
140   /// "intrinsic" ops depending on elemental type.
141   static void registerPatterns(LLVMTypeConverter &converter,
142                                RewritePatternSet &patterns) {
143     patterns
144         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
145                               typename Args::Intr64Op>...>(converter);
146   }
147 
148   /// Configures the conversion target to lower out "main" ops.
149   static void configureTarget(LLVMConversionTarget &target) {
150     target.addIllegalOp<typename Args::MainOp...>();
151     target.addLegalOp<typename Args::Intr32Op...>();
152     target.addLegalOp<typename Args::Intr64Op...>();
153   }
154 };
155 
156 using Registry = RegistryImpl<
157     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
158     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
159     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
160 
161 } // namespace
162 
163 /// Populate the given list with patterns that convert from X86Vector to LLVM.
164 void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
165     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
166   Registry::registerPatterns(converter, patterns);
167   patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
168       converter);
169 }
170 
171 void mlir::configureX86VectorLegalizeForExportTarget(
172     LLVMConversionTarget &target) {
173   Registry::configureTarget(target);
174   target.addLegalOp<MaskCompressIntrOp>();
175   target.addIllegalOp<MaskCompressOp>();
176   target.addLegalOp<RsqrtIntrOp>();
177   target.addIllegalOp<RsqrtOp>();
178   target.addLegalOp<DotIntrOp>();
179   target.addIllegalOp<DotOp>();
180 }
181