1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 10 #include "mlir/Conversion/LLVMCommon/Pattern.h" 11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 12 #include "mlir/Dialect/ArmSVE/Transforms.h" 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/BuiltinOps.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::arm_sve; 20 21 template <typename OpTy> 22 class ForwardOperands : public OpConversionPattern<OpTy> { 23 using OpConversionPattern<OpTy>::OpConversionPattern; 24 25 LogicalResult 26 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 27 ConversionPatternRewriter &rewriter) const final { 28 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) 29 return rewriter.notifyMatchFailure(op, "operand types already match"); 30 31 rewriter.updateRootInPlace( 32 op, [&]() { op->setOperands(adaptor.getOperands()); }); 33 return success(); 34 } 35 }; 36 37 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 38 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 39 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 40 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 41 using ScalableMaskedAddIOpLowering = 42 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 43 ScalableMaskedAddIIntrOp>; 44 using ScalableMaskedAddFOpLowering = 45 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 46 ScalableMaskedAddFIntrOp>; 47 using ScalableMaskedSubIOpLowering = 48 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 49 ScalableMaskedSubIIntrOp>; 50 using ScalableMaskedSubFOpLowering = 51 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 52 ScalableMaskedSubFIntrOp>; 53 using ScalableMaskedMulIOpLowering = 54 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 55 ScalableMaskedMulIIntrOp>; 56 using ScalableMaskedMulFOpLowering = 57 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 58 ScalableMaskedMulFIntrOp>; 59 using ScalableMaskedSDivIOpLowering = 60 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 61 ScalableMaskedSDivIIntrOp>; 62 using ScalableMaskedUDivIOpLowering = 63 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 64 ScalableMaskedUDivIIntrOp>; 65 using ScalableMaskedDivFOpLowering = 66 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 67 ScalableMaskedDivFIntrOp>; 68 69 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 70 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 71 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 72 // Populate conversion patterns 73 74 // clang-format off 75 patterns.add<ForwardOperands<func::CallOp>, 76 ForwardOperands<func::CallIndirectOp>, 77 ForwardOperands<func::ReturnOp>>(converter, 78 &converter.getContext()); 79 patterns.add<SdotOpLowering, 80 SmmlaOpLowering, 81 UdotOpLowering, 82 UmmlaOpLowering, 83 ScalableMaskedAddIOpLowering, 84 ScalableMaskedAddFOpLowering, 85 ScalableMaskedSubIOpLowering, 86 ScalableMaskedSubFOpLowering, 87 ScalableMaskedMulIOpLowering, 88 ScalableMaskedMulFOpLowering, 89 ScalableMaskedSDivIOpLowering, 90 ScalableMaskedUDivIOpLowering, 91 ScalableMaskedDivFOpLowering>(converter); 92 // clang-format on 93 } 94 95 void mlir::configureArmSVELegalizeForExportTarget( 96 LLVMConversionTarget &target) { 97 // clang-format off 98 target.addLegalOp<SdotIntrOp, 99 SmmlaIntrOp, 100 UdotIntrOp, 101 UmmlaIntrOp, 102 ScalableMaskedAddIIntrOp, 103 ScalableMaskedAddFIntrOp, 104 ScalableMaskedSubIIntrOp, 105 ScalableMaskedSubFIntrOp, 106 ScalableMaskedMulIIntrOp, 107 ScalableMaskedMulFIntrOp, 108 ScalableMaskedSDivIIntrOp, 109 ScalableMaskedUDivIIntrOp, 110 ScalableMaskedDivFIntrOp>(); 111 target.addIllegalOp<SdotOp, 112 SmmlaOp, 113 UdotOp, 114 UmmlaOp, 115 ScalableMaskedAddIOp, 116 ScalableMaskedAddFOp, 117 ScalableMaskedSubIOp, 118 ScalableMaskedSubFOp, 119 ScalableMaskedMulIOp, 120 ScalableMaskedMulFOp, 121 ScalableMaskedSDivIOp, 122 ScalableMaskedUDivIOp, 123 ScalableMaskedDivFOp>(); 124 // clang-format on 125 } 126