1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 16 #include "mlir/IR/BuiltinOps.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 using namespace mlir; 20 using namespace mlir::x86vector; 21 22 /// Extracts the "main" vector element type from the given X86Vector operation. 23 template <typename OpTy> 24 static Type getSrcVectorElementType(OpTy op) { 25 return op.src().getType().template cast<VectorType>().getElementType(); 26 } 27 template <> 28 Type getSrcVectorElementType(Vp2IntersectOp op) { 29 return op.a().getType().template cast<VectorType>().getElementType(); 30 } 31 32 namespace { 33 34 /// Base conversion for AVX512 ops that can be lowered to one of the two 35 /// intrinsics based on the bitwidth of their "main" vector element type. This 36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 37 /// results of multi-result intrinsic ops. 38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 40 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 41 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 42 43 LLVMTypeConverter &getTypeConverter() const { 44 return *static_cast<LLVMTypeConverter *>( 45 OpConversionPattern<OpTy>::getTypeConverter()); 46 } 47 48 LogicalResult 49 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 50 ConversionPatternRewriter &rewriter) const override { 51 Type elementType = getSrcVectorElementType<OpTy>(op); 52 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 53 if (bitwidth == 32) 54 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 55 operands, getTypeConverter(), 56 rewriter); 57 if (bitwidth == 64) 58 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 59 operands, getTypeConverter(), 60 rewriter); 61 return rewriter.notifyMatchFailure( 62 op, "expected 'src' to be either f32 or f64"); 63 } 64 }; 65 66 struct MaskCompressOpConversion 67 : public ConvertOpToLLVMPattern<MaskCompressOp> { 68 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 69 70 LogicalResult 71 matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands, 72 ConversionPatternRewriter &rewriter) const override { 73 MaskCompressOp::Adaptor adaptor(operands); 74 auto opType = adaptor.a().getType(); 75 76 Value src; 77 if (op.src()) { 78 src = adaptor.src(); 79 } else if (op.constant_src()) { 80 src = rewriter.create<ConstantOp>(op.getLoc(), opType, 81 op.constant_srcAttr()); 82 } else { 83 Attribute zeroAttr = rewriter.getZeroAttr(opType); 84 src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr); 85 } 86 87 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(), 88 src, adaptor.k()); 89 90 return success(); 91 } 92 }; 93 94 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { 95 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; 96 97 LogicalResult 98 matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands, 99 ConversionPatternRewriter &rewriter) const override { 100 RsqrtOp::Adaptor adaptor(operands); 101 102 auto opType = adaptor.a().getType(); 103 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a()); 104 return success(); 105 } 106 }; 107 108 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { 109 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; 110 111 LogicalResult 112 matchAndRewrite(DotOp op, ArrayRef<Value> operands, 113 ConversionPatternRewriter &rewriter) const override { 114 DotOp::Adaptor adaptor(operands); 115 auto opType = adaptor.a().getType(); 116 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); 117 // Dot product of all elements, broadcasted to all elements. 118 auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); 119 Value scale = 120 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); 121 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(), 122 scale); 123 return success(); 124 } 125 }; 126 127 /// An entry associating the "main" AVX512 op with its instantiations for 128 /// vectors of 32-bit and 64-bit elements. 129 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 130 struct RegEntry { 131 using MainOp = OpTy; 132 using Intr32Op = Intr32OpTy; 133 using Intr64Op = Intr64OpTy; 134 }; 135 136 /// A container for op association entries facilitating the configuration of 137 /// dialect conversion. 138 template <typename... Args> 139 struct RegistryImpl { 140 /// Registers the patterns specializing the "main" op to one of the 141 /// "intrinsic" ops depending on elemental type. 142 static void registerPatterns(LLVMTypeConverter &converter, 143 RewritePatternSet &patterns) { 144 patterns 145 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 146 typename Args::Intr64Op>...>(converter); 147 } 148 149 /// Configures the conversion target to lower out "main" ops. 150 static void configureTarget(LLVMConversionTarget &target) { 151 target.addIllegalOp<typename Args::MainOp...>(); 152 target.addLegalOp<typename Args::Intr32Op...>(); 153 target.addLegalOp<typename Args::Intr64Op...>(); 154 } 155 }; 156 157 using Registry = RegistryImpl< 158 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 159 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 160 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 161 162 } // namespace 163 164 /// Populate the given list with patterns that convert from X86Vector to LLVM. 165 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 166 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 167 Registry::registerPatterns(converter, patterns); 168 patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>( 169 converter); 170 } 171 172 void mlir::configureX86VectorLegalizeForExportTarget( 173 LLVMConversionTarget &target) { 174 Registry::configureTarget(target); 175 target.addLegalOp<MaskCompressIntrOp>(); 176 target.addIllegalOp<MaskCompressOp>(); 177 target.addLegalOp<RsqrtIntrOp>(); 178 target.addIllegalOp<RsqrtOp>(); 179 target.addLegalOp<DotIntrOp>(); 180 target.addIllegalOp<DotOp>(); 181 } 182