//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Verify that the given type is compatible with the LLVM dialect, emitting an
// error otherwise, and return the type unchanged (a null type passes through).
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  if (!LLVM::isCompatibleType(type))
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return type;
}

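// Converts an ArmSVE scalable vector type to the corresponding LLVM dialect
// scalable vector type, converting the element type through `converter` and
// using the trailing dimension as the minimum element count; returns an empty
// Optional if the element type cannot be converted. Schematically (exact
// assembly syntax depends on the MLIR revision), something like
// `!arm_sve.vector<4xi32>` maps to an LLVM scalable vector of 4 x i32.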
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  auto elementType = unwrap(converter.convertType(svType.getElementType()));
  if (!elementType)
    return {};

  auto sVectorType =
      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
  return sVectorType;
}

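// Rewrites `op` in place to use the converted operands produced by the type
// converter; fails to match when the operand types are already up to date.
// Used for ops such as calls and returns whose operand types change but whose
// structure is otherwise preserved.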
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

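// Unconditionally rewrites ReturnOp in place with the converted operands;
// similar to ForwardOperands<ReturnOp> but without the operand-type check.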
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

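// Source materialization registered with the type converter: when a value of
// LLVM scalable vector type needs to be consumed where the original ArmSVE
// scalable vector type is expected, bridge the gap with an
// unrealized_conversion_cast that later conversions are expected to fold away.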
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1 ||
      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
      .getResult(0);
}

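// Each of the following ArmSVE ops is lowered one-to-one to its LLVM IR
// intrinsic counterpart op; only the operand and result types change.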
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
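// Roughly (shown schematically, not verbatim IR): the memref descriptor is
// turned into a strided element pointer via getStridedElementPtr, that pointer
// is llvm.bitcast to a pointer-to-(scalable-)vector, and the result is read
// with a plain llvm.load.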
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(loadOp.getContext());

    auto resultType = loadOp.result().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
                                          converter)
              .getValue());
    }
    Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};

// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
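// Mirrors the load lowering above: compute the strided element pointer,
// llvm.bitcast it to a pointer-to-(scalable-)vector, then emit an llvm.store
// of the converted value.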
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(storeOp.getContext());

    auto resultType = storeOp.value().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
                                          converter)
              .getValue());
    }
    Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
                                               bitCastedPtr);
    return success();
  }
};

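// The unmasked scalable arithmetic ops map directly onto the regular LLVM
// arithmetic ops, which accept scalable vector operands.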
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}

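// Mark the unmasked scalable arithmetic ops illegal so the patterns above are
// guaranteed to fire during conversion.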
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}

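// Scalable vector comparisons are expressed with the regular LLVM compare
// ops; applied to scalable operands they yield a scalable vector of i1 (the
// mask).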
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}

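// Mark the scalable compare ops illegal so the mask-generation patterns above
// must apply.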
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  // clang-format off
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

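/// Configure the conversion target: the LLVM intrinsic ops produced by the
/// patterns above are legal, the higher-level ArmSVE ops are illegal, and
/// function/call/return ops remain legal only once no ArmSVE scalable vector
/// types appear in their signatures or operand/result types.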
void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on
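  // Functions, calls, and returns are only legal once no ArmSVE scalable
  // vector type is left in their signatures, operands, or results; the
  // ForwardOperands and function signature conversion patterns registered
  // above rewrite them into that form.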
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}