1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::x86vector; 21 22 /// Extracts the "main" vector element type from the given X86Vector operation. 23 template <typename OpTy> 24 static Type getSrcVectorElementType(OpTy op) { 25 return op.getSrc().getType().template cast<VectorType>().getElementType(); 26 } 27 template <> 28 Type getSrcVectorElementType(Vp2IntersectOp op) { 29 return op.getA().getType().template cast<VectorType>().getElementType(); 30 } 31 32 namespace { 33 34 /// Base conversion for AVX512 ops that can be lowered to one of the two 35 /// intrinsics based on the bitwidth of their "main" vector element type. This 36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 37 /// results of multi-result intrinsic ops. 38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 40 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 41 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 42 43 LLVMTypeConverter &getTypeConverter() const { 44 return *static_cast<LLVMTypeConverter *>( 45 OpConversionPattern<OpTy>::getTypeConverter()); 46 } 47 48 LogicalResult 49 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 50 ConversionPatternRewriter &rewriter) const override { 51 Type elementType = getSrcVectorElementType<OpTy>(op); 52 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 53 if (bitwidth == 32) 54 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 55 adaptor.getOperands(), 56 getTypeConverter(), rewriter); 57 if (bitwidth == 64) 58 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 59 adaptor.getOperands(), 60 getTypeConverter(), rewriter); 61 return rewriter.notifyMatchFailure( 62 op, "expected 'src' to be either f32 or f64"); 63 } 64 }; 65 66 struct MaskCompressOpConversion 67 : public ConvertOpToLLVMPattern<MaskCompressOp> { 68 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 69 70 LogicalResult 71 matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, 72 ConversionPatternRewriter &rewriter) const override { 73 auto opType = adaptor.getA().getType(); 74 75 Value src; 76 if (op.getSrc()) { 77 src = adaptor.getSrc(); 78 } else if (op.getConstantSrc()) { 79 src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType, 80 op.getConstantSrcAttr()); 81 } else { 82 Attribute zeroAttr = rewriter.getZeroAttr(opType); 83 src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr); 84 } 85 86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(), 87 src, adaptor.getK()); 88 89 return success(); 90 } 91 }; 92 93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 94 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, 98 ConversionPatternRewriter &rewriter) const override { 99 auto opType = adaptor.getA().getType(); 100 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA()); 101 return success(); 102 } 103 }; 104 105 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 106 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 107 108 LogicalResult 109 matchAndRewrite(DotOp op, OpAdaptor adaptor, 110 ConversionPatternRewriter &rewriter) const override { 111 auto opType = adaptor.getA().getType(); 112 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 113 // Dot product of all elements, broadcasted to all elements. 114 auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); 115 Value scale = 116 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 117 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(), 118 adaptor.getB(), scale); 119 return success(); 120 } 121 }; 122 123 /// An entry associating the "main" AVX512 op with its instantiations for 124 /// vectors of 32-bit and 64-bit elements. 125 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 126 struct RegEntry { 127 using MainOp = OpTy; 128 using Intr32Op = Intr32OpTy; 129 using Intr64Op = Intr64OpTy; 130 }; 131 132 /// A container for op association entries facilitating the configuration of 133 /// dialect conversion. 134 template <typename... Args> 135 struct RegistryImpl { 136 /// Registers the patterns specializing the "main" op to one of the 137 /// "intrinsic" ops depending on elemental type. 138 static void registerPatterns(LLVMTypeConverter &converter, 139 RewritePatternSet &patterns) { 140 patterns 141 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 142 typename Args::Intr64Op>...>(converter); 143 } 144 145 /// Configures the conversion target to lower out "main" ops. 146 static void configureTarget(LLVMConversionTarget &target) { 147 target.addIllegalOp<typename Args::MainOp...>(); 148 target.addLegalOp<typename Args::Intr32Op...>(); 149 target.addLegalOp<typename Args::Intr64Op...>(); 150 } 151 }; 152 153 using Registry = RegistryImpl< 154 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 155 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 156 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 157 158 } // namespace 159 160 /// Populate the given list with patterns that convert from X86Vector to LLVM. 161 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 162 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 163 Registry::registerPatterns(converter, patterns); 164 patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>( 165 converter); 166 } 167 168 void mlir::configureX86VectorLegalizeForExportTarget( 169 LLVMConversionTarget &target) { 170 Registry::configureTarget(target); 171 target.addLegalOp<MaskCompressIntrOp>(); 172 target.addIllegalOp<MaskCompressOp>(); 173 target.addLegalOp<RsqrtIntrOp>(); 174 target.addIllegalOp<RsqrtOp>(); 175 target.addLegalOp<DotIntrOp>(); 176 target.addIllegalOp<DotOp>(); 177 } 178