1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 10 #include "mlir/Conversion/LLVMCommon/Pattern.h" 11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 12 #include "mlir/Dialect/ArmSVE/Transforms.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::arm_sve; 20 21 template <typename OpTy> 22 class ForwardOperands : public OpConversionPattern<OpTy> { 23 using OpConversionPattern<OpTy>::OpConversionPattern; 24 25 LogicalResult 26 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 27 ConversionPatternRewriter &rewriter) const final { 28 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) 29 return rewriter.notifyMatchFailure(op, "operand types already match"); 30 31 rewriter.updateRootInPlace( 32 op, [&]() { op->setOperands(adaptor.getOperands()); }); 33 return success(); 34 } 35 }; 36 37 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 38 public: 39 using OpConversionPattern<ReturnOp>::OpConversionPattern; 40 41 LogicalResult 42 matchAndRewrite(ReturnOp op, OpAdaptor adaptor, 43 ConversionPatternRewriter &rewriter) const final { 44 rewriter.updateRootInPlace( 45 op, [&]() { op->setOperands(adaptor.getOperands()); }); 46 return success(); 47 } 48 }; 49 50 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 51 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 52 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 53 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 54 using ScalableMaskedAddIOpLowering = 55 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 56 ScalableMaskedAddIIntrOp>; 57 using ScalableMaskedAddFOpLowering = 58 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 59 ScalableMaskedAddFIntrOp>; 60 using ScalableMaskedSubIOpLowering = 61 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 62 ScalableMaskedSubIIntrOp>; 63 using ScalableMaskedSubFOpLowering = 64 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 65 ScalableMaskedSubFIntrOp>; 66 using ScalableMaskedMulIOpLowering = 67 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 68 ScalableMaskedMulIIntrOp>; 69 using ScalableMaskedMulFOpLowering = 70 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 71 ScalableMaskedMulFIntrOp>; 72 using ScalableMaskedSDivIOpLowering = 73 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 74 ScalableMaskedSDivIIntrOp>; 75 using ScalableMaskedUDivIOpLowering = 76 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 77 ScalableMaskedUDivIIntrOp>; 78 using ScalableMaskedDivFOpLowering = 79 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 80 ScalableMaskedDivFIntrOp>; 81 82 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 83 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 84 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 85 // Populate conversion patterns 86 87 // clang-format off 88 patterns.add<ForwardOperands<CallOp>, 89 ForwardOperands<CallIndirectOp>, 90 ForwardOperands<ReturnOp>>(converter, 91 &converter.getContext()); 92 patterns.add<SdotOpLowering, 93 SmmlaOpLowering, 94 UdotOpLowering, 95 UmmlaOpLowering, 96 ScalableMaskedAddIOpLowering, 97 ScalableMaskedAddFOpLowering, 98 ScalableMaskedSubIOpLowering, 99 ScalableMaskedSubFOpLowering, 100 ScalableMaskedMulIOpLowering, 101 ScalableMaskedMulFOpLowering, 102 ScalableMaskedSDivIOpLowering, 103 ScalableMaskedUDivIOpLowering, 104 ScalableMaskedDivFOpLowering>(converter); 105 // clang-format on 106 } 107 108 void mlir::configureArmSVELegalizeForExportTarget( 109 LLVMConversionTarget &target) { 110 // clang-format off 111 target.addLegalOp<SdotIntrOp, 112 SmmlaIntrOp, 113 UdotIntrOp, 114 UmmlaIntrOp, 115 ScalableMaskedAddIIntrOp, 116 ScalableMaskedAddFIntrOp, 117 ScalableMaskedSubIIntrOp, 118 ScalableMaskedSubFIntrOp, 119 ScalableMaskedMulIIntrOp, 120 ScalableMaskedMulFIntrOp, 121 ScalableMaskedSDivIIntrOp, 122 ScalableMaskedUDivIIntrOp, 123 ScalableMaskedDivFIntrOp>(); 124 target.addIllegalOp<SdotOp, 125 SmmlaOp, 126 UdotOp, 127 UmmlaOp, 128 ScalableMaskedAddIOp, 129 ScalableMaskedAddFOp, 130 ScalableMaskedSubIOp, 131 ScalableMaskedSubFOp, 132 ScalableMaskedMulIOp, 133 ScalableMaskedMulFOp, 134 ScalableMaskedSDivIOp, 135 ScalableMaskedUDivIOp, 136 ScalableMaskedDivFOp>(); 137 // clang-format on 138 } 139