//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
88508a63bSEmilio Cota 
98508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/Transforms.h"
108508a63bSEmilio Cota 
1175e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1275e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
148508a63bSEmilio Cota #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
158508a63bSEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
168508a63bSEmilio Cota #include "mlir/IR/BuiltinOps.h"
178508a63bSEmilio Cota #include "mlir/IR/PatternMatch.h"
188508a63bSEmilio Cota 
198508a63bSEmilio Cota using namespace mlir;
208508a63bSEmilio Cota using namespace mlir::x86vector;
218508a63bSEmilio Cota 
228508a63bSEmilio Cota /// Extracts the "main" vector element type from the given X86Vector operation.
238508a63bSEmilio Cota template <typename OpTy>
getSrcVectorElementType(OpTy op)248508a63bSEmilio Cota static Type getSrcVectorElementType(OpTy op) {
25*8df54a6aSJacques Pienaar   return op.getSrc().getType().template cast<VectorType>().getElementType();
268508a63bSEmilio Cota }
278508a63bSEmilio Cota template <>
getSrcVectorElementType(Vp2IntersectOp op)288508a63bSEmilio Cota Type getSrcVectorElementType(Vp2IntersectOp op) {
29*8df54a6aSJacques Pienaar   return op.getA().getType().template cast<VectorType>().getElementType();
308508a63bSEmilio Cota }
318508a63bSEmilio Cota 
328508a63bSEmilio Cota namespace {
338508a63bSEmilio Cota 
348508a63bSEmilio Cota /// Base conversion for AVX512 ops that can be lowered to one of the two
358508a63bSEmilio Cota /// intrinsics based on the bitwidth of their "main" vector element type. This
368508a63bSEmilio Cota /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
378508a63bSEmilio Cota /// results of multi-result intrinsic ops.
388508a63bSEmilio Cota template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
398508a63bSEmilio Cota struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
LowerToIntrinsic__anon338b04f50111::LowerToIntrinsic408508a63bSEmilio Cota   explicit LowerToIntrinsic(LLVMTypeConverter &converter)
418508a63bSEmilio Cota       : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
428508a63bSEmilio Cota 
getTypeConverter__anon338b04f50111::LowerToIntrinsic438508a63bSEmilio Cota   LLVMTypeConverter &getTypeConverter() const {
448508a63bSEmilio Cota     return *static_cast<LLVMTypeConverter *>(
458508a63bSEmilio Cota         OpConversionPattern<OpTy>::getTypeConverter());
468508a63bSEmilio Cota   }
478508a63bSEmilio Cota 
488508a63bSEmilio Cota   LogicalResult
matchAndRewrite__anon338b04f50111::LowerToIntrinsic49b54c724bSRiver Riddle   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
508508a63bSEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
518508a63bSEmilio Cota     Type elementType = getSrcVectorElementType<OpTy>(op);
528508a63bSEmilio Cota     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
538508a63bSEmilio Cota     if (bitwidth == 32)
548508a63bSEmilio Cota       return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
55b54c724bSRiver Riddle                                            adaptor.getOperands(),
56b54c724bSRiver Riddle                                            getTypeConverter(), rewriter);
578508a63bSEmilio Cota     if (bitwidth == 64)
588508a63bSEmilio Cota       return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
59b54c724bSRiver Riddle                                            adaptor.getOperands(),
60b54c724bSRiver Riddle                                            getTypeConverter(), rewriter);
618508a63bSEmilio Cota     return rewriter.notifyMatchFailure(
628508a63bSEmilio Cota         op, "expected 'src' to be either f32 or f64");
638508a63bSEmilio Cota   }
648508a63bSEmilio Cota };
658508a63bSEmilio Cota 
668508a63bSEmilio Cota struct MaskCompressOpConversion
678508a63bSEmilio Cota     : public ConvertOpToLLVMPattern<MaskCompressOp> {
688508a63bSEmilio Cota   using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
698508a63bSEmilio Cota 
708508a63bSEmilio Cota   LogicalResult
matchAndRewrite__anon338b04f50111::MaskCompressOpConversion71b54c724bSRiver Riddle   matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
728508a63bSEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
73*8df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
748508a63bSEmilio Cota 
758508a63bSEmilio Cota     Value src;
76*8df54a6aSJacques Pienaar     if (op.getSrc()) {
77*8df54a6aSJacques Pienaar       src = adaptor.getSrc();
78*8df54a6aSJacques Pienaar     } else if (op.getConstantSrc()) {
79a54f4eaeSMogball       src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80*8df54a6aSJacques Pienaar                                                op.getConstantSrcAttr());
818508a63bSEmilio Cota     } else {
828508a63bSEmilio Cota       Attribute zeroAttr = rewriter.getZeroAttr(opType);
83a54f4eaeSMogball       src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
848508a63bSEmilio Cota     }
858508a63bSEmilio Cota 
86*8df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87*8df54a6aSJacques Pienaar                                                     src, adaptor.getK());
888508a63bSEmilio Cota 
898508a63bSEmilio Cota     return success();
908508a63bSEmilio Cota   }
918508a63bSEmilio Cota };
928508a63bSEmilio Cota 
930b63e322SEmilio Cota struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
940b63e322SEmilio Cota   using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
950b63e322SEmilio Cota 
960b63e322SEmilio Cota   LogicalResult
matchAndRewrite__anon338b04f50111::RsqrtOpConversion97b54c724bSRiver Riddle   matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
980b63e322SEmilio Cota                   ConversionPatternRewriter &rewriter) const override {
99*8df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
100*8df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
1010b63e322SEmilio Cota     return success();
1020b63e322SEmilio Cota   }
1030b63e322SEmilio Cota };
1040b63e322SEmilio Cota 
105916f3e16SAart Bik struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
106916f3e16SAart Bik   using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
107916f3e16SAart Bik 
108916f3e16SAart Bik   LogicalResult
matchAndRewrite__anon338b04f50111::DotOpConversion109b54c724bSRiver Riddle   matchAndRewrite(DotOp op, OpAdaptor adaptor,
110916f3e16SAart Bik                   ConversionPatternRewriter &rewriter) const override {
111*8df54a6aSJacques Pienaar     auto opType = adaptor.getA().getType();
112916f3e16SAart Bik     Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
113916f3e16SAart Bik     // Dot product of all elements, broadcasted to all elements.
11466b1e629SMatthias Springer     auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
115916f3e16SAart Bik     Value scale =
116916f3e16SAart Bik         rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
117*8df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
118*8df54a6aSJacques Pienaar                                            adaptor.getB(), scale);
119916f3e16SAart Bik     return success();
120916f3e16SAart Bik   }
121916f3e16SAart Bik };
122916f3e16SAart Bik 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops it
/// lowers to: one for vectors of 32-bit elements and one for vectors of
/// 64-bit elements. It carries no data; only the nested type aliases are
/// used.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
struct RegEntry {
  /// The op to be lowered.
  using MainOp = OpTy;
  /// The intrinsic op for 32-bit elements.
  using Intr32Op = Intr32OpTy;
  /// The intrinsic op for 64-bit elements.
  using Intr64Op = Intr64OpTy;
};
1318508a63bSEmilio Cota 
1328508a63bSEmilio Cota /// A container for op association entries facilitating the configuration of
1338508a63bSEmilio Cota /// dialect conversion.
1348508a63bSEmilio Cota template <typename... Args>
1358508a63bSEmilio Cota struct RegistryImpl {
1368508a63bSEmilio Cota   /// Registers the patterns specializing the "main" op to one of the
1378508a63bSEmilio Cota   /// "intrinsic" ops depending on elemental type.
registerPatterns__anon338b04f50111::RegistryImpl1388508a63bSEmilio Cota   static void registerPatterns(LLVMTypeConverter &converter,
1398508a63bSEmilio Cota                                RewritePatternSet &patterns) {
1408508a63bSEmilio Cota     patterns
1418508a63bSEmilio Cota         .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
1428508a63bSEmilio Cota                               typename Args::Intr64Op>...>(converter);
1438508a63bSEmilio Cota   }
1448508a63bSEmilio Cota 
1458508a63bSEmilio Cota   /// Configures the conversion target to lower out "main" ops.
configureTarget__anon338b04f50111::RegistryImpl1468508a63bSEmilio Cota   static void configureTarget(LLVMConversionTarget &target) {
1478508a63bSEmilio Cota     target.addIllegalOp<typename Args::MainOp...>();
1488508a63bSEmilio Cota     target.addLegalOp<typename Args::Intr32Op...>();
1498508a63bSEmilio Cota     target.addLegalOp<typename Args::Intr64Op...>();
1508508a63bSEmilio Cota   }
1518508a63bSEmilio Cota };
1528508a63bSEmilio Cota 
1538508a63bSEmilio Cota using Registry = RegistryImpl<
1548508a63bSEmilio Cota     RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
1558508a63bSEmilio Cota     RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
1568508a63bSEmilio Cota     RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
1578508a63bSEmilio Cota 
1588508a63bSEmilio Cota } // namespace
1598508a63bSEmilio Cota 
1608508a63bSEmilio Cota /// Populate the given list with patterns that convert from X86Vector to LLVM.
populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)1618508a63bSEmilio Cota void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
1628508a63bSEmilio Cota     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1638508a63bSEmilio Cota   Registry::registerPatterns(converter, patterns);
164916f3e16SAart Bik   patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165916f3e16SAart Bik       converter);
1668508a63bSEmilio Cota }
1678508a63bSEmilio Cota 
configureX86VectorLegalizeForExportTarget(LLVMConversionTarget & target)1688508a63bSEmilio Cota void mlir::configureX86VectorLegalizeForExportTarget(
1698508a63bSEmilio Cota     LLVMConversionTarget &target) {
1708508a63bSEmilio Cota   Registry::configureTarget(target);
1718508a63bSEmilio Cota   target.addLegalOp<MaskCompressIntrOp>();
1728508a63bSEmilio Cota   target.addIllegalOp<MaskCompressOp>();
1730b63e322SEmilio Cota   target.addLegalOp<RsqrtIntrOp>();
1740b63e322SEmilio Cota   target.addIllegalOp<RsqrtOp>();
175916f3e16SAart Bik   target.addLegalOp<DotIntrOp>();
176916f3e16SAart Bik   target.addIllegalOp<DotOp>();
1778508a63bSEmilio Cota }
178