1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 10 #include "mlir/Conversion/LLVMCommon/Pattern.h" 11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 12 #include "mlir/Dialect/ArmSVE/Transforms.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/PatternMatch.h" 16 17 using namespace mlir; 18 using namespace mlir::arm_sve; 19 20 template <typename OpTy> 21 class ForwardOperands : public OpConversionPattern<OpTy> { 22 using OpConversionPattern<OpTy>::OpConversionPattern; 23 24 LogicalResult 25 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 26 ConversionPatternRewriter &rewriter) const final { 27 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) 28 return rewriter.notifyMatchFailure(op, "operand types already match"); 29 30 rewriter.updateRootInPlace( 31 op, [&]() { op->setOperands(adaptor.getOperands()); }); 32 return success(); 33 } 34 }; 35 36 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 37 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 38 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 39 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 40 using ScalableMaskedAddIOpLowering = 41 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 42 ScalableMaskedAddIIntrOp>; 43 using ScalableMaskedAddFOpLowering = 44 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 45 ScalableMaskedAddFIntrOp>; 46 using ScalableMaskedSubIOpLowering = 47 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 48 ScalableMaskedSubIIntrOp>; 49 using ScalableMaskedSubFOpLowering = 50 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 51 ScalableMaskedSubFIntrOp>; 52 using ScalableMaskedMulIOpLowering = 53 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 54 ScalableMaskedMulIIntrOp>; 55 using ScalableMaskedMulFOpLowering = 56 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 57 ScalableMaskedMulFIntrOp>; 58 using ScalableMaskedSDivIOpLowering = 59 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 60 ScalableMaskedSDivIIntrOp>; 61 using ScalableMaskedUDivIOpLowering = 62 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 63 ScalableMaskedUDivIIntrOp>; 64 using ScalableMaskedDivFOpLowering = 65 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 66 ScalableMaskedDivFIntrOp>; 67 68 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 69 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 70 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 71 // Populate conversion patterns 72 73 // clang-format off 74 patterns.add<ForwardOperands<CallOp>, 75 ForwardOperands<CallIndirectOp>, 76 ForwardOperands<ReturnOp>>(converter, 77 &converter.getContext()); 78 patterns.add<SdotOpLowering, 79 SmmlaOpLowering, 80 UdotOpLowering, 81 UmmlaOpLowering, 82 ScalableMaskedAddIOpLowering, 83 ScalableMaskedAddFOpLowering, 84 ScalableMaskedSubIOpLowering, 85 ScalableMaskedSubFOpLowering, 86 ScalableMaskedMulIOpLowering, 87 ScalableMaskedMulFOpLowering, 88 ScalableMaskedSDivIOpLowering, 89 ScalableMaskedUDivIOpLowering, 90 ScalableMaskedDivFOpLowering>(converter); 91 // clang-format on 92 } 93 94 void mlir::configureArmSVELegalizeForExportTarget( 95 LLVMConversionTarget &target) { 96 // clang-format off 97 target.addLegalOp<SdotIntrOp, 98 SmmlaIntrOp, 99 UdotIntrOp, 100 UmmlaIntrOp, 101 ScalableMaskedAddIIntrOp, 102 ScalableMaskedAddFIntrOp, 103 ScalableMaskedSubIIntrOp, 104 ScalableMaskedSubFIntrOp, 105 ScalableMaskedMulIIntrOp, 106 ScalableMaskedMulFIntrOp, 107 ScalableMaskedSDivIIntrOp, 108 ScalableMaskedUDivIIntrOp, 109 ScalableMaskedDivFIntrOp>(); 110 target.addIllegalOp<SdotOp, 111 SmmlaOp, 112 UdotOp, 113 UmmlaOp, 114 ScalableMaskedAddIOp, 115 ScalableMaskedAddFOp, 116 ScalableMaskedSubIOp, 117 ScalableMaskedSubFOp, 118 ScalableMaskedMulIOp, 119 ScalableMaskedMulFOp, 120 ScalableMaskedSDivIOp, 121 ScalableMaskedUDivIOp, 122 ScalableMaskedDivFOp>(); 123 // clang-format on 124 } 125