1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::arm_sve;
20 
21 template <typename OpTy>
22 class ForwardOperands : public OpConversionPattern<OpTy> {
23   using OpConversionPattern<OpTy>::OpConversionPattern;
24 
25   LogicalResult
26   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
27                   ConversionPatternRewriter &rewriter) const final {
28     if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
29       return rewriter.notifyMatchFailure(op, "operand types already match");
30 
31     rewriter.updateRootInPlace(
32         op, [&]() { op->setOperands(adaptor.getOperands()); });
33     return success();
34   }
35 };
36 
37 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
38 public:
39   using OpConversionPattern<ReturnOp>::OpConversionPattern;
40 
41   LogicalResult
42   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
43                   ConversionPatternRewriter &rewriter) const final {
44     rewriter.updateRootInPlace(
45         op, [&]() { op->setOperands(adaptor.getOperands()); });
46     return success();
47   }
48 };
49 
50 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
51 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
52 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
53 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
54 using ScalableMaskedAddIOpLowering =
55     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
56                                  ScalableMaskedAddIIntrOp>;
57 using ScalableMaskedAddFOpLowering =
58     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
59                                  ScalableMaskedAddFIntrOp>;
60 using ScalableMaskedSubIOpLowering =
61     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
62                                  ScalableMaskedSubIIntrOp>;
63 using ScalableMaskedSubFOpLowering =
64     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
65                                  ScalableMaskedSubFIntrOp>;
66 using ScalableMaskedMulIOpLowering =
67     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
68                                  ScalableMaskedMulIIntrOp>;
69 using ScalableMaskedMulFOpLowering =
70     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
71                                  ScalableMaskedMulFIntrOp>;
72 using ScalableMaskedSDivIOpLowering =
73     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
74                                  ScalableMaskedSDivIIntrOp>;
75 using ScalableMaskedUDivIOpLowering =
76     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
77                                  ScalableMaskedUDivIIntrOp>;
78 using ScalableMaskedDivFOpLowering =
79     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
80                                  ScalableMaskedDivFIntrOp>;
81 
82 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
83 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
84     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
85   // Populate conversion patterns
86 
87   // clang-format off
88   patterns.add<ForwardOperands<CallOp>,
89                ForwardOperands<CallIndirectOp>,
90                ForwardOperands<ReturnOp>>(converter,
91                                           &converter.getContext());
92   patterns.add<SdotOpLowering,
93                SmmlaOpLowering,
94                UdotOpLowering,
95                UmmlaOpLowering,
96                ScalableMaskedAddIOpLowering,
97                ScalableMaskedAddFOpLowering,
98                ScalableMaskedSubIOpLowering,
99                ScalableMaskedSubFOpLowering,
100                ScalableMaskedMulIOpLowering,
101                ScalableMaskedMulFOpLowering,
102                ScalableMaskedSDivIOpLowering,
103                ScalableMaskedUDivIOpLowering,
104                ScalableMaskedDivFOpLowering>(converter);
105   // clang-format on
106 }
107 
108 void mlir::configureArmSVELegalizeForExportTarget(
109     LLVMConversionTarget &target) {
110   // clang-format off
111   target.addLegalOp<SdotIntrOp,
112                     SmmlaIntrOp,
113                     UdotIntrOp,
114                     UmmlaIntrOp,
115                     ScalableMaskedAddIIntrOp,
116                     ScalableMaskedAddFIntrOp,
117                     ScalableMaskedSubIIntrOp,
118                     ScalableMaskedSubFIntrOp,
119                     ScalableMaskedMulIIntrOp,
120                     ScalableMaskedMulFIntrOp,
121                     ScalableMaskedSDivIIntrOp,
122                     ScalableMaskedUDivIIntrOp,
123                     ScalableMaskedDivFIntrOp>();
124   target.addIllegalOp<SdotOp,
125                       SmmlaOp,
126                       UdotOp,
127                       UmmlaOp,
128                       ScalableMaskedAddIOp,
129                       ScalableMaskedAddFOp,
130                       ScalableMaskedSubIOp,
131                       ScalableMaskedSubFOp,
132                       ScalableMaskedMulIOp,
133                       ScalableMaskedMulFOp,
134                       ScalableMaskedSDivIOp,
135                       ScalableMaskedUDivIOp,
136                       ScalableMaskedDivFOp>();
137   // clang-format on
138 }
139