1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::x86vector; 20 21 /// Extracts the "main" vector element type from the given X86Vector operation. 22 template <typename OpTy> 23 static Type getSrcVectorElementType(OpTy op) { 24 return op.src().getType().template cast<VectorType>().getElementType(); 25 } 26 template <> 27 Type getSrcVectorElementType(Vp2IntersectOp op) { 28 return op.a().getType().template cast<VectorType>().getElementType(); 29 } 30 31 namespace { 32 33 /// Base conversion for AVX512 ops that can be lowered to one of the two 34 /// intrinsics based on the bitwidth of their "main" vector element type. This 35 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 36 /// results of multi-result intrinsic ops. 37 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 38 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 39 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 40 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 41 42 LLVMTypeConverter &getTypeConverter() const { 43 return *static_cast<LLVMTypeConverter *>( 44 OpConversionPattern<OpTy>::getTypeConverter()); 45 } 46 47 LogicalResult 48 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const override { 50 Type elementType = getSrcVectorElementType<OpTy>(op); 51 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 52 if (bitwidth == 32) 53 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 54 operands, getTypeConverter(), 55 rewriter); 56 if (bitwidth == 64) 57 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 58 operands, getTypeConverter(), 59 rewriter); 60 return rewriter.notifyMatchFailure( 61 op, "expected 'src' to be either f32 or f64"); 62 } 63 }; 64 65 struct MaskCompressOpConversion 66 : public ConvertOpToLLVMPattern<MaskCompressOp> { 67 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 68 69 LogicalResult 70 matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 MaskCompressOp::Adaptor adaptor(operands); 73 auto opType = adaptor.a().getType(); 74 75 Value src; 76 if (op.src()) { 77 src = adaptor.src(); 78 } else if (op.constant_src()) { 79 src = rewriter.create<ConstantOp>(op.getLoc(), opType, 80 op.constant_srcAttr()); 81 } else { 82 Attribute zeroAttr = rewriter.getZeroAttr(opType); 83 src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr); 84 } 85 86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(), 87 src, adaptor.k()); 88 89 return success(); 90 } 91 }; 92 93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 94 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 95 96 LogicalResult 97 matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands, 98 ConversionPatternRewriter &rewriter) const override { 99 RsqrtOp::Adaptor adaptor(operands); 100 101 auto opType = adaptor.a().getType(); 102 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a()); 103 return success(); 104 } 105 }; 106 107 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 108 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 109 110 LogicalResult 111 matchAndRewrite(DotOp op, ArrayRef<Value> operands, 112 ConversionPatternRewriter &rewriter) const override { 113 DotOp::Adaptor adaptor(operands); 114 auto opType = adaptor.a().getType(); 115 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 116 // Dot product of all elements, broadcasted to all elements. 117 auto attr = rewriter.getI8IntegerAttr(0xff); 118 Value scale = 119 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 120 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(), 121 scale); 122 return success(); 123 } 124 }; 125 126 /// An entry associating the "main" AVX512 op with its instantiations for 127 /// vectors of 32-bit and 64-bit elements. 128 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 129 struct RegEntry { 130 using MainOp = OpTy; 131 using Intr32Op = Intr32OpTy; 132 using Intr64Op = Intr64OpTy; 133 }; 134 135 /// A container for op association entries facilitating the configuration of 136 /// dialect conversion. 137 template <typename... Args> 138 struct RegistryImpl { 139 /// Registers the patterns specializing the "main" op to one of the 140 /// "intrinsic" ops depending on elemental type. 141 static void registerPatterns(LLVMTypeConverter &converter, 142 RewritePatternSet &patterns) { 143 patterns 144 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 145 typename Args::Intr64Op>...>(converter); 146 } 147 148 /// Configures the conversion target to lower out "main" ops. 149 static void configureTarget(LLVMConversionTarget &target) { 150 target.addIllegalOp<typename Args::MainOp...>(); 151 target.addLegalOp<typename Args::Intr32Op...>(); 152 target.addLegalOp<typename Args::Intr64Op...>(); 153 } 154 }; 155 156 using Registry = RegistryImpl< 157 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 158 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 159 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 160 161 } // namespace 162 163 /// Populate the given list with patterns that convert from X86Vector to LLVM. 164 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 165 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 166 Registry::registerPatterns(converter, patterns); 167 patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>( 168 converter); 169 } 170 171 void mlir::configureX86VectorLegalizeForExportTarget( 172 LLVMConversionTarget &target) { 173 Registry::configureTarget(target); 174 target.addLegalOp<MaskCompressIntrOp>(); 175 target.addIllegalOp<MaskCompressOp>(); 176 target.addLegalOp<RsqrtIntrOp>(); 177 target.addIllegalOp<RsqrtOp>(); 178 target.addLegalOp<DotIntrOp>(); 179 target.addIllegalOp<DotOp>(); 180 } 181