1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
10 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
11 #include "mlir/Dialect/ArmSVE/Transforms.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::arm_sve;
19 
20 // Extract an LLVM IR type from the LLVM IR dialect type.
21 static Type unwrap(Type type) {
22   if (!type)
23     return nullptr;
24   auto *mlirContext = type.getContext();
25   if (!LLVM::isCompatibleType(type))
26     emitError(UnknownLoc::get(mlirContext),
27               "conversion resulted in a non-LLVM type");
28   return type;
29 }
30 
31 static Optional<Type>
32 convertScalableVectorTypeToLLVM(ScalableVectorType svType,
33                                 LLVMTypeConverter &converter) {
34   auto elementType = unwrap(converter.convertType(svType.getElementType()));
35   if (!elementType)
36     return {};
37 
38   auto sVectorType =
39       LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
40   return sVectorType;
41 }
42 
43 template <typename OpTy>
44 class ForwardOperands : public OpConversionPattern<OpTy> {
45   using OpConversionPattern<OpTy>::OpConversionPattern;
46 
47   LogicalResult
48   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const final {
50     if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
51       return rewriter.notifyMatchFailure(op, "operand types already match");
52 
53     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
54     return success();
55   }
56 };
57 
58 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
59 public:
60   using OpConversionPattern<ReturnOp>::OpConversionPattern;
61 
62   LogicalResult
63   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
64                   ConversionPatternRewriter &rewriter) const final {
65     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
66     return success();
67   }
68 };
69 
70 static Optional<Value> addUnrealizedCast(OpBuilder &builder,
71                                          ScalableVectorType svType,
72                                          ValueRange inputs, Location loc) {
73   if (inputs.size() != 1 ||
74       !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
75     return Value();
76   return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
77       .getResult(0);
78 }
79 
80 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
81 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
82 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
83 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
84 using VectorScaleOpLowering =
85     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
86 using ScalableMaskedAddIOpLowering =
87     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
88                                  ScalableMaskedAddIIntrOp>;
89 using ScalableMaskedAddFOpLowering =
90     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
91                                  ScalableMaskedAddFIntrOp>;
92 using ScalableMaskedSubIOpLowering =
93     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
94                                  ScalableMaskedSubIIntrOp>;
95 using ScalableMaskedSubFOpLowering =
96     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
97                                  ScalableMaskedSubFIntrOp>;
98 using ScalableMaskedMulIOpLowering =
99     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
100                                  ScalableMaskedMulIIntrOp>;
101 using ScalableMaskedMulFOpLowering =
102     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
103                                  ScalableMaskedMulFIntrOp>;
104 using ScalableMaskedSDivIOpLowering =
105     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
106                                  ScalableMaskedSDivIIntrOp>;
107 using ScalableMaskedUDivIOpLowering =
108     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
109                                  ScalableMaskedUDivIIntrOp>;
110 using ScalableMaskedDivFOpLowering =
111     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
112                                  ScalableMaskedDivFIntrOp>;
113 
114 static void
115 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
116                                          OwningRewritePatternList &patterns) {
117   // clang-format off
118   patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
119                OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
120                OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
121                OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
122                OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
123                OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
124                OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
125                OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
126                OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
127               >(converter);
128   // clang-format on
129 }
130 
131 static void
132 configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
133   // clang-format off
134   target.addIllegalOp<ScalableAddIOp,
135                       ScalableAddFOp,
136                       ScalableSubIOp,
137                       ScalableSubFOp,
138                       ScalableMulIOp,
139                       ScalableMulFOp,
140                       ScalableSDivIOp,
141                       ScalableUDivIOp,
142                       ScalableDivFOp>();
143   // clang-format on
144 }
145 
146 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
147 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
148     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
149   // Populate conversion patterns
150   // Remove any ArmSVE-specific types from function signatures and results.
151   populateFuncOpTypeConversionPattern(patterns, converter);
152   converter.addConversion([&converter](ScalableVectorType svType) {
153     return convertScalableVectorTypeToLLVM(svType, converter);
154   });
155   converter.addSourceMaterialization(addUnrealizedCast);
156 
157   // clang-format off
158   patterns.add<ForwardOperands<CallOp>,
159                ForwardOperands<CallIndirectOp>,
160                ForwardOperands<ReturnOp>>(converter,
161                                           &converter.getContext());
162   patterns.add<SdotOpLowering,
163                SmmlaOpLowering,
164                UdotOpLowering,
165                UmmlaOpLowering,
166                VectorScaleOpLowering,
167                ScalableMaskedAddIOpLowering,
168                ScalableMaskedAddFOpLowering,
169                ScalableMaskedSubIOpLowering,
170                ScalableMaskedSubFOpLowering,
171                ScalableMaskedMulIOpLowering,
172                ScalableMaskedMulFOpLowering,
173                ScalableMaskedSDivIOpLowering,
174                ScalableMaskedUDivIOpLowering,
175                ScalableMaskedDivFOpLowering>(converter);
176   // clang-format on
177   populateBasicSVEArithmeticExportPatterns(converter, patterns);
178 }
179 
180 void mlir::configureArmSVELegalizeForExportTarget(
181     LLVMConversionTarget &target) {
182   // clang-format off
183   target.addLegalOp<SdotIntrOp,
184                     SmmlaIntrOp,
185                     UdotIntrOp,
186                     UmmlaIntrOp,
187                     VectorScaleIntrOp,
188                     ScalableMaskedAddIIntrOp,
189                     ScalableMaskedAddFIntrOp,
190                     ScalableMaskedSubIIntrOp,
191                     ScalableMaskedSubFIntrOp,
192                     ScalableMaskedMulIIntrOp,
193                     ScalableMaskedMulFIntrOp,
194                     ScalableMaskedSDivIIntrOp,
195                     ScalableMaskedUDivIIntrOp,
196                     ScalableMaskedDivFIntrOp>();
197   target.addIllegalOp<SdotOp,
198                       SmmlaOp,
199                       UdotOp,
200                       UmmlaOp,
201                       VectorScaleOp,
202                       ScalableMaskedAddIOp,
203                       ScalableMaskedAddFOp,
204                       ScalableMaskedSubIOp,
205                       ScalableMaskedSubFOp,
206                       ScalableMaskedMulIOp,
207                       ScalableMaskedMulFOp,
208                       ScalableMaskedSDivIOp,
209                       ScalableMaskedUDivIOp,
210                       ScalableMaskedDivFOp>();
211   // clang-format on
212   auto hasScalableVectorType = [](TypeRange types) {
213     for (Type type : types)
214       if (type.isa<arm_sve::ScalableVectorType>())
215         return true;
216     return false;
217   };
218   target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
219     return !hasScalableVectorType(op.getType().getInputs()) &&
220            !hasScalableVectorType(op.getType().getResults());
221   });
222   target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
223       [hasScalableVectorType](Operation *op) {
224         return !hasScalableVectorType(op->getOperandTypes()) &&
225                !hasScalableVectorType(op->getResultTypes());
226       });
227   configureBasicSVEArithmeticLegalizations(target);
228 }
229