1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::arm_sve;
20 
21 template <typename OpTy>
22 class ForwardOperands : public OpConversionPattern<OpTy> {
23   using OpConversionPattern<OpTy>::OpConversionPattern;
24 
25   LogicalResult
matchAndRewrite(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const26   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
27                   ConversionPatternRewriter &rewriter) const final {
28     if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
29       return rewriter.notifyMatchFailure(op, "operand types already match");
30 
31     rewriter.updateRootInPlace(
32         op, [&]() { op->setOperands(adaptor.getOperands()); });
33     return success();
34   }
35 };
36 
37 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
38 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
39 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
40 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
41 using ScalableMaskedAddIOpLowering =
42     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
43                                  ScalableMaskedAddIIntrOp>;
44 using ScalableMaskedAddFOpLowering =
45     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
46                                  ScalableMaskedAddFIntrOp>;
47 using ScalableMaskedSubIOpLowering =
48     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
49                                  ScalableMaskedSubIIntrOp>;
50 using ScalableMaskedSubFOpLowering =
51     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
52                                  ScalableMaskedSubFIntrOp>;
53 using ScalableMaskedMulIOpLowering =
54     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
55                                  ScalableMaskedMulIIntrOp>;
56 using ScalableMaskedMulFOpLowering =
57     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
58                                  ScalableMaskedMulFIntrOp>;
59 using ScalableMaskedSDivIOpLowering =
60     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
61                                  ScalableMaskedSDivIIntrOp>;
62 using ScalableMaskedUDivIOpLowering =
63     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
64                                  ScalableMaskedUDivIIntrOp>;
65 using ScalableMaskedDivFOpLowering =
66     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
67                                  ScalableMaskedDivFIntrOp>;
68 
69 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)70 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
71     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
72   // Populate conversion patterns
73 
74   // clang-format off
75   patterns.add<ForwardOperands<func::CallOp>,
76                ForwardOperands<func::CallIndirectOp>,
77                ForwardOperands<func::ReturnOp>>(converter,
78                                           &converter.getContext());
79   patterns.add<SdotOpLowering,
80                SmmlaOpLowering,
81                UdotOpLowering,
82                UmmlaOpLowering,
83                ScalableMaskedAddIOpLowering,
84                ScalableMaskedAddFOpLowering,
85                ScalableMaskedSubIOpLowering,
86                ScalableMaskedSubFOpLowering,
87                ScalableMaskedMulIOpLowering,
88                ScalableMaskedMulFOpLowering,
89                ScalableMaskedSDivIOpLowering,
90                ScalableMaskedUDivIOpLowering,
91                ScalableMaskedDivFOpLowering>(converter);
92   // clang-format on
93 }
94 
configureArmSVELegalizeForExportTarget(LLVMConversionTarget & target)95 void mlir::configureArmSVELegalizeForExportTarget(
96     LLVMConversionTarget &target) {
97   // clang-format off
98   target.addLegalOp<SdotIntrOp,
99                     SmmlaIntrOp,
100                     UdotIntrOp,
101                     UmmlaIntrOp,
102                     ScalableMaskedAddIIntrOp,
103                     ScalableMaskedAddFIntrOp,
104                     ScalableMaskedSubIIntrOp,
105                     ScalableMaskedSubFIntrOp,
106                     ScalableMaskedMulIIntrOp,
107                     ScalableMaskedMulFIntrOp,
108                     ScalableMaskedSDivIIntrOp,
109                     ScalableMaskedUDivIIntrOp,
110                     ScalableMaskedDivFIntrOp>();
111   target.addIllegalOp<SdotOp,
112                       SmmlaOp,
113                       UdotOp,
114                       UmmlaOp,
115                       ScalableMaskedAddIOp,
116                       ScalableMaskedAddFOp,
117                       ScalableMaskedSubIOp,
118                       ScalableMaskedSubFOp,
119                       ScalableMaskedMulIOp,
120                       ScalableMaskedMulFOp,
121                       ScalableMaskedSDivIOp,
122                       ScalableMaskedUDivIOp,
123                       ScalableMaskedDivFOp>();
124   // clang-format on
125 }
126