1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::x86vector; 20 21 /// Extracts the "main" vector element type from the given X86Vector operation. 22 template <typename OpTy> 23 static Type getSrcVectorElementType(OpTy op) { 24 return op.src().getType().template cast<VectorType>().getElementType(); 25 } 26 template <> 27 Type getSrcVectorElementType(Vp2IntersectOp op) { 28 return op.a().getType().template cast<VectorType>().getElementType(); 29 } 30 31 namespace { 32 33 /// Base conversion for AVX512 ops that can be lowered to one of the two 34 /// intrinsics based on the bitwidth of their "main" vector element type. This 35 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 36 /// results of multi-result intrinsic ops. 37 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 38 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 39 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 40 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 41 42 LLVMTypeConverter &getTypeConverter() const { 43 return *static_cast<LLVMTypeConverter *>( 44 OpConversionPattern<OpTy>::getTypeConverter()); 45 } 46 47 LogicalResult 48 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const override { 50 Type elementType = getSrcVectorElementType<OpTy>(op); 51 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 52 if (bitwidth == 32) 53 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 54 operands, getTypeConverter(), 55 rewriter); 56 if (bitwidth == 64) 57 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 58 operands, getTypeConverter(), 59 rewriter); 60 return rewriter.notifyMatchFailure( 61 op, "expected 'src' to be either f32 or f64"); 62 } 63 }; 64 65 struct MaskCompressOpConversion 66 : public ConvertOpToLLVMPattern<MaskCompressOp> { 67 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 68 69 LogicalResult 70 matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 MaskCompressOp::Adaptor adaptor(operands); 73 auto opType = adaptor.a().getType(); 74 75 Value src; 76 if (op.src()) { 77 src = adaptor.src(); 78 } else if (op.constant_src()) { 79 src = rewriter.create<ConstantOp>(op.getLoc(), opType, 80 op.constant_srcAttr()); 81 } else { 82 Attribute zeroAttr = rewriter.getZeroAttr(opType); 83 src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr); 84 } 85 86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(), 87 src, adaptor.k()); 88 89 return success(); 90 } 91 }; 92 93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 94 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands, 98 ConversionPatternRewriter &rewriter) const override { 99 RsqrtOp::Adaptor adaptor(operands); 100 101 auto opType = adaptor.a().getType(); 102 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a()); 103 return success(); 104 } 105 }; 106 107 /// An entry associating the "main" AVX512 op with its instantiations for 108 /// vectors of 32-bit and 64-bit elements. 109 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 110 struct RegEntry { 111 using MainOp = OpTy; 112 using Intr32Op = Intr32OpTy; 113 using Intr64Op = Intr64OpTy; 114 }; 115 116 /// A container for op association entries facilitating the configuration of 117 /// dialect conversion. 118 template <typename... Args> 119 struct RegistryImpl { 120 /// Registers the patterns specializing the "main" op to one of the 121 /// "intrinsic" ops depending on elemental type. 122 static void registerPatterns(LLVMTypeConverter &converter, 123 RewritePatternSet &patterns) { 124 patterns 125 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 126 typename Args::Intr64Op>...>(converter); 127 } 128 129 /// Configures the conversion target to lower out "main" ops. 130 static void configureTarget(LLVMConversionTarget &target) { 131 target.addIllegalOp<typename Args::MainOp...>(); 132 target.addLegalOp<typename Args::Intr32Op...>(); 133 target.addLegalOp<typename Args::Intr64Op...>(); 134 } 135 }; 136 137 using Registry = RegistryImpl< 138 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 139 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 140 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 141 142 } // namespace 143 144 /// Populate the given list with patterns that convert from X86Vector to LLVM. 145 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 146 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 147 Registry::registerPatterns(converter, patterns); 148 patterns.add<MaskCompressOpConversion, RsqrtOpConversion>(converter); 149 } 150 151 void mlir::configureX86VectorLegalizeForExportTarget( 152 LLVMConversionTarget &target) { 153 Registry::configureTarget(target); 154 target.addLegalOp<MaskCompressIntrOp>(); 155 target.addIllegalOp<MaskCompressOp>(); 156 target.addLegalOp<RsqrtIntrOp>(); 157 target.addIllegalOp<RsqrtOp>(); 158 } 159