//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::x86vector; /// Extracts the "main" vector element type from the given X86Vector operation. template static Type getSrcVectorElementType(OpTy op) { return op.getSrc().getType().template cast().getElementType(); } template <> Type getSrcVectorElementType(Vp2IntersectOp op) { return op.getA().getType().template cast().getElementType(); } namespace { /// Base conversion for AVX512 ops that can be lowered to one of the two /// intrinsics based on the bitwidth of their "main" vector element type. This /// relies on the to-LLVM-dialect conversion helpers to correctly pack the /// results of multi-result intrinsic ops. 
// NOTE(review): `<...>` template arguments appear to have been stripped
// throughout this section. Affected spots: the bare `template` keywords, the
// base classes named without their op argument, and the `create(...)`,
// `replaceOpWithNewOp(...)`, `addLegalOp()` and `addIllegalOp()` calls with no
// op type. Restore them from upstream before building — TODO confirm.

// Template parameters, as used below: `OpTy` is the op being matched, and
// `Intr32OpTy` / `Intr64OpTy` are the ops it is rewritten to when the "main"
// element bitwidth is 32 or 64 respectively.
template struct LowerToIntrinsic : public OpConversionPattern {
  // Registers the pattern with the given LLVM type converter and its context.
  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
      : OpConversionPattern(converter, &converter.getContext()) {}

  // Returns the pattern's type converter as an LLVMTypeConverter. The downcast
  // relies on the constructor above having been given an LLVMTypeConverter.
  LLVMTypeConverter &getTypeConverter() const {
    return *static_cast(
        OpConversionPattern::getTypeConverter());
  }

  // Rewrites `op` into Intr32OpTy or Intr64OpTy depending on the bitwidth of
  // its "main" element type, forwarding the adaptor's operands. Any other
  // bitwidth is reported as a match failure.
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elementType = getSrcVectorElementType(op);
    unsigned bitwidth = elementType.getIntOrFloatBitWidth();
    if (bitwidth == 32)
      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    if (bitwidth == 64)
      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    // NOTE(review): this message names `src` and f32/f64, but for
    // Vp2IntersectOp the element type is read from operand `a` (see the
    // getSrcVectorElementType specialization), so the wording may mislead for
    // that op.
    return rewriter.notifyMatchFailure(
        op, "expected 'src' to be either f32 or f64");
  }
};

/// Lowers MaskCompressOp by replacing it with a single op built from
/// (a, src, k), whose result type is the type of adaptor operand `a`.
///
/// The `src` value passed to the replacement is chosen in priority order:
///   1. the op's own `src` operand, if present;
///   2. otherwise a constant created from the op's constant-src attribute;
///   3. otherwise a zero constant of the same type as operand `a`.
// NOTE(review): the types of the created ops were stripped. Given the
// Arithmetic include, the constants are presumably arith constants and the
// replacement is presumably the compress intrinsic op — confirm.
struct MaskCompressOpConversion : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Used both as the replacement's result type and as the constant's type.
    auto opType = adaptor.getA().getType();

    Value src;
    if (op.getSrc()) {
      src = adaptor.getSrc();
    } else if (op.getConstantSrc()) {
      src = rewriter.create(op.getLoc(), opType, op.getConstantSrcAttr());
    } else {
      Attribute zeroAttr = rewriter.getZeroAttr(opType);
      src = rewriter.create(op->getLoc(), opType, zeroAttr);
    }

    rewriter.replaceOpWithNewOp(op, opType, adaptor.getA(), src,
                                adaptor.getK());

    return success();
  }
};

/// Lowers RsqrtOp by replacing it with a single op whose only operand is `a`
/// and whose result type is the type of `a`.
struct RsqrtOpConversion : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = adaptor.getA().getType();
    rewriter.replaceOpWithNewOp(op, opType, adaptor.getA());
    return success();
  }
};

/// Lowers DotOp by replacing it with a single op built from (a, b, scale).
/// Here `scale` is an i8 constant with value 0xff, and the result type is the
/// type of operand `a`.
struct DotOpConversion : public ConvertOpToLLVMPattern {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto opType = adaptor.getA().getType();
    Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
    // Dot product of all elements, broadcasted to all elements.
    // NOTE(review): the cast's target type was stripped; presumably an 8-bit
    // integer type to match getI8IntegerAttr — confirm.
    auto attr = rewriter.getI8IntegerAttr(static_cast(0xff));
    Value scale = rewriter.create(op.getLoc(), llvmIntType, attr);
    rewriter.replaceOpWithNewOp(op, opType, adaptor.getA(), adaptor.getB(),
                                scale);
    return success();
  }
};

// NOTE(review): the template parameter list was stripped; the aliases below
// refer to `OpTy`, `Intr32OpTy` and `Intr64OpTy`.
/// An entry associating the "main" AVX512 op with its instantiations for
/// vectors of 32-bit and 64-bit elements.
template struct RegEntry {
  using MainOp = OpTy;
  using Intr32Op = Intr32OpTy;
  using Intr64Op = Intr64OpTy;
};

// NOTE(review): the template parameter list was stripped; the `...` pack
// expansion in registerPatterns suggests it was a parameter pack of RegEntry
// types — confirm.
/// A container for op association entries facilitating the configuration of
/// dialect conversion.
template struct RegistryImpl {
  /// Registers the patterns specializing the "main" op to one of the
  /// "intrinsic" ops depending on elemental type.
  static void registerPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns) {
    patterns
        .add...>(converter);
  }

  /// Configures the conversion target to lower out "main" ops.
  // The first call marks ops illegal and the next two mark ops legal. These
  // are presumably each entry's MainOp, then its Intr32Op and Intr64Op (the
  // op type arguments were stripped).
  static void configureTarget(LLVMConversionTarget &target) {
    target.addIllegalOp();
    target.addLegalOp();
    target.addLegalOp();
  }
};

// NOTE(review): the three RegEntry argument lists were stripped. Given the
// getSrcVectorElementType specialization, Vp2IntersectOp is presumably one of
// the "main" ops listed here — confirm.
/// The set of ops lowered through LowerToIntrinsic.
using Registry =
    RegistryImpl< RegEntry, RegEntry, RegEntry>;

} // namespace

/// Populate the given list with patterns that convert from X86Vector to LLVM.
///
/// Adds the LowerToIntrinsic patterns described by Registry, then the
/// hand-written conversions. The hand-written set is presumably
/// MaskCompressOpConversion, RsqrtOpConversion and DotOpConversion; the `add`
/// template arguments were stripped.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  Registry::registerPatterns(converter, patterns);
  patterns.add(
      converter);
}

/// Configures the legality of X86Vector ops and their replacement ops on
/// `target`.
///
/// Applies Registry's rules first, then three legal/illegal pairs. These are
/// presumably the intrinsic op (legal) and the main op (illegal) for
/// MaskCompress, Rsqrt and Dot; the op type arguments were stripped.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  Registry::configureTarget(target);
  target.addLegalOp();
  target.addIllegalOp();
  target.addLegalOp();
  target.addIllegalOp();
  target.addLegalOp();
  target.addIllegalOp();
}