1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/PatternMatch.h" 19 20 using namespace mlir; 21 using namespace mlir::x86vector; 22 23 /// Extracts the "main" vector element type from the given X86Vector operation. 24 template <typename OpTy> 25 static Type getSrcVectorElementType(OpTy op) { 26 return op.src().getType().template cast<VectorType>().getElementType(); 27 } 28 template <> 29 Type getSrcVectorElementType(Vp2IntersectOp op) { 30 return op.a().getType().template cast<VectorType>().getElementType(); 31 } 32 33 namespace { 34 35 /// Base conversion for AVX512 ops that can be lowered to one of the two 36 /// intrinsics based on the bitwidth of their "main" vector element type. This 37 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 38 /// results of multi-result intrinsic ops. 39 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 40 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 41 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 42 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 43 44 LLVMTypeConverter &getTypeConverter() const { 45 return *static_cast<LLVMTypeConverter *>( 46 OpConversionPattern<OpTy>::getTypeConverter()); 47 } 48 49 LogicalResult 50 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 51 ConversionPatternRewriter &rewriter) const override { 52 Type elementType = getSrcVectorElementType<OpTy>(op); 53 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 54 if (bitwidth == 32) 55 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 56 adaptor.getOperands(), 57 getTypeConverter(), rewriter); 58 if (bitwidth == 64) 59 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 60 adaptor.getOperands(), 61 getTypeConverter(), rewriter); 62 return rewriter.notifyMatchFailure( 63 op, "expected 'src' to be either f32 or f64"); 64 } 65 }; 66 67 struct MaskCompressOpConversion 68 : public ConvertOpToLLVMPattern<MaskCompressOp> { 69 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 70 71 LogicalResult 72 matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, 73 ConversionPatternRewriter &rewriter) const override { 74 auto opType = adaptor.a().getType(); 75 76 Value src; 77 if (op.src()) { 78 src = adaptor.src(); 79 } else if (op.constant_src()) { 80 src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType, 81 op.constant_srcAttr()); 82 } else { 83 Attribute zeroAttr = rewriter.getZeroAttr(opType); 84 src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr); 85 } 86 87 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(), 88 src, adaptor.k()); 89 90 return success(); 91 } 92 }; 93 94 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 95 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 96 97 LogicalResult 98 matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, 99 ConversionPatternRewriter &rewriter) const override { 100 auto opType = adaptor.a().getType(); 101 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a()); 102 return success(); 103 } 104 }; 105 106 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 107 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 108 109 LogicalResult 110 matchAndRewrite(DotOp op, OpAdaptor adaptor, 111 ConversionPatternRewriter &rewriter) const override { 112 auto opType = adaptor.a().getType(); 113 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 114 // Dot product of all elements, broadcasted to all elements. 115 auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); 116 Value scale = 117 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 118 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(), 119 scale); 120 return success(); 121 } 122 }; 123 124 /// An entry associating the "main" AVX512 op with its instantiations for 125 /// vectors of 32-bit and 64-bit elements. 126 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 127 struct RegEntry { 128 using MainOp = OpTy; 129 using Intr32Op = Intr32OpTy; 130 using Intr64Op = Intr64OpTy; 131 }; 132 133 /// A container for op association entries facilitating the configuration of 134 /// dialect conversion. 135 template <typename... Args> 136 struct RegistryImpl { 137 /// Registers the patterns specializing the "main" op to one of the 138 /// "intrinsic" ops depending on elemental type. 139 static void registerPatterns(LLVMTypeConverter &converter, 140 RewritePatternSet &patterns) { 141 patterns 142 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 143 typename Args::Intr64Op>...>(converter); 144 } 145 146 /// Configures the conversion target to lower out "main" ops. 147 static void configureTarget(LLVMConversionTarget &target) { 148 target.addIllegalOp<typename Args::MainOp...>(); 149 target.addLegalOp<typename Args::Intr32Op...>(); 150 target.addLegalOp<typename Args::Intr64Op...>(); 151 } 152 }; 153 154 using Registry = RegistryImpl< 155 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 156 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 157 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 158 159 } // namespace 160 161 /// Populate the given list with patterns that convert from X86Vector to LLVM. 162 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 163 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 164 Registry::registerPatterns(converter, patterns); 165 patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>( 166 converter); 167 } 168 169 void mlir::configureX86VectorLegalizeForExportTarget( 170 LLVMConversionTarget &target) { 171 Registry::configureTarget(target); 172 target.addLegalOp<MaskCompressIntrOp>(); 173 target.addIllegalOp<MaskCompressOp>(); 174 target.addLegalOp<RsqrtIntrOp>(); 175 target.addIllegalOp<RsqrtOp>(); 176 target.addLegalOp<DotIntrOp>(); 177 target.addIllegalOp<DotOp>(); 178 } 179