1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/X86Vector/Transforms.h" 10 11 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/Dialect/X86Vector/X86VectorDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::x86vector; 20 21 /// Extracts the "main" vector element type from the given X86Vector operation. 22 template <typename OpTy> 23 static Type getSrcVectorElementType(OpTy op) { 24 return op.src().getType().template cast<VectorType>().getElementType(); 25 } 26 template <> 27 Type getSrcVectorElementType(Vp2IntersectOp op) { 28 return op.a().getType().template cast<VectorType>().getElementType(); 29 } 30 31 namespace { 32 33 /// Base conversion for AVX512 ops that can be lowered to one of the two 34 /// intrinsics based on the bitwidth of their "main" vector element type. This 35 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the 36 /// results of multi-result intrinsic ops. 37 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 38 struct LowerToIntrinsic : public OpConversionPattern<OpTy> { 39 explicit LowerToIntrinsic(LLVMTypeConverter &converter) 40 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} 41 42 LLVMTypeConverter &getTypeConverter() const { 43 return *static_cast<LLVMTypeConverter *>( 44 OpConversionPattern<OpTy>::getTypeConverter()); 45 } 46 47 LogicalResult 48 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const override { 50 Type elementType = getSrcVectorElementType<OpTy>(op); 51 unsigned bitwidth = elementType.getIntOrFloatBitWidth(); 52 if (bitwidth == 32) 53 return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), 54 operands, getTypeConverter(), 55 rewriter); 56 if (bitwidth == 64) 57 return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), 58 operands, getTypeConverter(), 59 rewriter); 60 return rewriter.notifyMatchFailure( 61 op, "expected 'src' to be either f32 or f64"); 62 } 63 }; 64 65 struct MaskCompressOpConversion 66 : public ConvertOpToLLVMPattern<MaskCompressOp> { 67 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; 68 69 LogicalResult 70 matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands, 71 ConversionPatternRewriter &rewriter) const override { 72 MaskCompressOp::Adaptor adaptor(operands); 73 auto opType = adaptor.a().getType(); 74 75 Value src; 76 if (op.src()) { 77 src = adaptor.src(); 78 } else if (op.constant_src()) { 79 src = rewriter.create<ConstantOp>(op.getLoc(), opType, 80 op.constant_srcAttr()); 81 } else { 82 Attribute zeroAttr = rewriter.getZeroAttr(opType); 83 src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr); 84 } 85 86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(), 87 src, adaptor.k()); 88 89 return success(); 90 } 91 }; 92 93 /// An entry associating the "main" AVX512 op with its instantiations for 94 /// vectors of 32-bit and 64-bit elements. 95 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> 96 struct RegEntry { 97 using MainOp = OpTy; 98 using Intr32Op = Intr32OpTy; 99 using Intr64Op = Intr64OpTy; 100 }; 101 102 /// A container for op association entries facilitating the configuration of 103 /// dialect conversion. 104 template <typename... Args> 105 struct RegistryImpl { 106 /// Registers the patterns specializing the "main" op to one of the 107 /// "intrinsic" ops depending on elemental type. 108 static void registerPatterns(LLVMTypeConverter &converter, 109 RewritePatternSet &patterns) { 110 patterns 111 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, 112 typename Args::Intr64Op>...>(converter); 113 } 114 115 /// Configures the conversion target to lower out "main" ops. 116 static void configureTarget(LLVMConversionTarget &target) { 117 target.addIllegalOp<typename Args::MainOp...>(); 118 target.addLegalOp<typename Args::Intr32Op...>(); 119 target.addLegalOp<typename Args::Intr64Op...>(); 120 } 121 }; 122 123 using Registry = RegistryImpl< 124 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, 125 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, 126 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; 127 128 } // namespace 129 130 /// Populate the given list with patterns that convert from X86Vector to LLVM. 131 void mlir::populateX86VectorLegalizeForLLVMExportPatterns( 132 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 133 Registry::registerPatterns(converter, patterns); 134 patterns.add<MaskCompressOpConversion>(converter); 135 } 136 137 void mlir::configureX86VectorLegalizeForExportTarget( 138 LLVMConversionTarget &target) { 139 Registry::configureTarget(target); 140 target.addLegalOp<MaskCompressIntrOp>(); 141 target.addIllegalOp<MaskCompressOp>(); 142 } 143