1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17
18 using namespace mlir;
19 using namespace mlir::arm_sve;
20
21 template <typename OpTy>
22 class ForwardOperands : public OpConversionPattern<OpTy> {
23 using OpConversionPattern<OpTy>::OpConversionPattern;
24
25 LogicalResult
matchAndRewrite(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const26 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
27 ConversionPatternRewriter &rewriter) const final {
28 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
29 return rewriter.notifyMatchFailure(op, "operand types already match");
30
31 rewriter.updateRootInPlace(
32 op, [&]() { op->setOperands(adaptor.getOperands()); });
33 return success();
34 }
35 };
36
37 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
38 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
39 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
40 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
41 using ScalableMaskedAddIOpLowering =
42 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
43 ScalableMaskedAddIIntrOp>;
44 using ScalableMaskedAddFOpLowering =
45 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
46 ScalableMaskedAddFIntrOp>;
47 using ScalableMaskedSubIOpLowering =
48 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
49 ScalableMaskedSubIIntrOp>;
50 using ScalableMaskedSubFOpLowering =
51 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
52 ScalableMaskedSubFIntrOp>;
53 using ScalableMaskedMulIOpLowering =
54 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
55 ScalableMaskedMulIIntrOp>;
56 using ScalableMaskedMulFOpLowering =
57 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
58 ScalableMaskedMulFIntrOp>;
59 using ScalableMaskedSDivIOpLowering =
60 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
61 ScalableMaskedSDivIIntrOp>;
62 using ScalableMaskedUDivIOpLowering =
63 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
64 ScalableMaskedUDivIIntrOp>;
65 using ScalableMaskedDivFOpLowering =
66 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
67 ScalableMaskedDivFIntrOp>;
68
69 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)70 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
71 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
72 // Populate conversion patterns
73
74 // clang-format off
75 patterns.add<ForwardOperands<func::CallOp>,
76 ForwardOperands<func::CallIndirectOp>,
77 ForwardOperands<func::ReturnOp>>(converter,
78 &converter.getContext());
79 patterns.add<SdotOpLowering,
80 SmmlaOpLowering,
81 UdotOpLowering,
82 UmmlaOpLowering,
83 ScalableMaskedAddIOpLowering,
84 ScalableMaskedAddFOpLowering,
85 ScalableMaskedSubIOpLowering,
86 ScalableMaskedSubFOpLowering,
87 ScalableMaskedMulIOpLowering,
88 ScalableMaskedMulFOpLowering,
89 ScalableMaskedSDivIOpLowering,
90 ScalableMaskedUDivIOpLowering,
91 ScalableMaskedDivFOpLowering>(converter);
92 // clang-format on
93 }
94
configureArmSVELegalizeForExportTarget(LLVMConversionTarget & target)95 void mlir::configureArmSVELegalizeForExportTarget(
96 LLVMConversionTarget &target) {
97 // clang-format off
98 target.addLegalOp<SdotIntrOp,
99 SmmlaIntrOp,
100 UdotIntrOp,
101 UmmlaIntrOp,
102 ScalableMaskedAddIIntrOp,
103 ScalableMaskedAddFIntrOp,
104 ScalableMaskedSubIIntrOp,
105 ScalableMaskedSubFIntrOp,
106 ScalableMaskedMulIIntrOp,
107 ScalableMaskedMulFIntrOp,
108 ScalableMaskedSDivIIntrOp,
109 ScalableMaskedUDivIIntrOp,
110 ScalableMaskedDivFIntrOp>();
111 target.addIllegalOp<SdotOp,
112 SmmlaOp,
113 UdotOp,
114 UmmlaOp,
115 ScalableMaskedAddIOp,
116 ScalableMaskedAddFOp,
117 ScalableMaskedSubIOp,
118 ScalableMaskedSubFOp,
119 ScalableMaskedMulIOp,
120 ScalableMaskedMulFOp,
121 ScalableMaskedSDivIOp,
122 ScalableMaskedUDivIOp,
123 ScalableMaskedDivFOp>();
124 // clang-format on
125 }
126