1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::arm_sve;
19 
20 template <typename OpTy>
21 class ForwardOperands : public OpConversionPattern<OpTy> {
22   using OpConversionPattern<OpTy>::OpConversionPattern;
23 
24   LogicalResult
25   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
26                   ConversionPatternRewriter &rewriter) const final {
27     if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
28       return rewriter.notifyMatchFailure(op, "operand types already match");
29 
30     rewriter.updateRootInPlace(
31         op, [&]() { op->setOperands(adaptor.getOperands()); });
32     return success();
33   }
34 };
35 
36 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
37 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
38 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
39 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
40 using ScalableMaskedAddIOpLowering =
41     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
42                                  ScalableMaskedAddIIntrOp>;
43 using ScalableMaskedAddFOpLowering =
44     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
45                                  ScalableMaskedAddFIntrOp>;
46 using ScalableMaskedSubIOpLowering =
47     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
48                                  ScalableMaskedSubIIntrOp>;
49 using ScalableMaskedSubFOpLowering =
50     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
51                                  ScalableMaskedSubFIntrOp>;
52 using ScalableMaskedMulIOpLowering =
53     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
54                                  ScalableMaskedMulIIntrOp>;
55 using ScalableMaskedMulFOpLowering =
56     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
57                                  ScalableMaskedMulFIntrOp>;
58 using ScalableMaskedSDivIOpLowering =
59     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
60                                  ScalableMaskedSDivIIntrOp>;
61 using ScalableMaskedUDivIOpLowering =
62     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
63                                  ScalableMaskedUDivIIntrOp>;
64 using ScalableMaskedDivFOpLowering =
65     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
66                                  ScalableMaskedDivFIntrOp>;
67 
68 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
69 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
70     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
71   // Populate conversion patterns
72 
73   // clang-format off
74   patterns.add<ForwardOperands<CallOp>,
75                ForwardOperands<CallIndirectOp>,
76                ForwardOperands<ReturnOp>>(converter,
77                                           &converter.getContext());
78   patterns.add<SdotOpLowering,
79                SmmlaOpLowering,
80                UdotOpLowering,
81                UmmlaOpLowering,
82                ScalableMaskedAddIOpLowering,
83                ScalableMaskedAddFOpLowering,
84                ScalableMaskedSubIOpLowering,
85                ScalableMaskedSubFOpLowering,
86                ScalableMaskedMulIOpLowering,
87                ScalableMaskedMulFOpLowering,
88                ScalableMaskedSDivIOpLowering,
89                ScalableMaskedUDivIOpLowering,
90                ScalableMaskedDivFOpLowering>(converter);
91   // clang-format on
92 }
93 
94 void mlir::configureArmSVELegalizeForExportTarget(
95     LLVMConversionTarget &target) {
96   // clang-format off
97   target.addLegalOp<SdotIntrOp,
98                     SmmlaIntrOp,
99                     UdotIntrOp,
100                     UmmlaIntrOp,
101                     ScalableMaskedAddIIntrOp,
102                     ScalableMaskedAddFIntrOp,
103                     ScalableMaskedSubIIntrOp,
104                     ScalableMaskedSubFIntrOp,
105                     ScalableMaskedMulIIntrOp,
106                     ScalableMaskedMulFIntrOp,
107                     ScalableMaskedSDivIIntrOp,
108                     ScalableMaskedUDivIIntrOp,
109                     ScalableMaskedDivFIntrOp>();
110   target.addIllegalOp<SdotOp,
111                       SmmlaOp,
112                       UdotOp,
113                       UmmlaOp,
114                       ScalableMaskedAddIOp,
115                       ScalableMaskedAddFOp,
116                       ScalableMaskedSubIOp,
117                       ScalableMaskedSubFOp,
118                       ScalableMaskedMulIOp,
119                       ScalableMaskedMulFOp,
120                       ScalableMaskedSDivIOp,
121                       ScalableMaskedUDivIOp,
122                       ScalableMaskedDivFOp>();
123   // clang-format on
124 }
125