1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
10 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
11 #include "mlir/Dialect/ArmSVE/Transforms.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::arm_sve;
19 
20 // Extract an LLVM IR type from the LLVM IR dialect type.
21 static Type unwrap(Type type) {
22   if (!type)
23     return nullptr;
24   auto *mlirContext = type.getContext();
25   if (!LLVM::isCompatibleType(type))
26     emitError(UnknownLoc::get(mlirContext),
27               "conversion resulted in a non-LLVM type");
28   return type;
29 }
30 
31 static Optional<Type>
32 convertScalableVectorTypeToLLVM(ScalableVectorType svType,
33                                 LLVMTypeConverter &converter) {
34   auto elementType = unwrap(converter.convertType(svType.getElementType()));
35   if (!elementType)
36     return {};
37 
38   auto sVectorType =
39       LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
40   return sVectorType;
41 }
42 
43 template <typename OpTy>
44 class ForwardOperands : public OpConversionPattern<OpTy> {
45   using OpConversionPattern<OpTy>::OpConversionPattern;
46 
47   LogicalResult
48   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const final {
50     if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
51       return rewriter.notifyMatchFailure(op, "operand types already match");
52 
53     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
54     return success();
55   }
56 };
57 
58 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
59 public:
60   using OpConversionPattern<ReturnOp>::OpConversionPattern;
61 
62   LogicalResult
63   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
64                   ConversionPatternRewriter &rewriter) const final {
65     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
66     return success();
67   }
68 };
69 
70 static Optional<Value> addUnrealizedCast(OpBuilder &builder,
71                                          ScalableVectorType svType,
72                                          ValueRange inputs, Location loc) {
73   if (inputs.size() != 1 ||
74       !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
75     return Value();
76   return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
77       .getResult(0);
78 }
79 
80 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
81 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
82 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
83 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
84 using VectorScaleOpLowering =
85     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
86 using ScalableMaskedAddIOpLowering =
87     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
88                                  ScalableMaskedAddIIntrOp>;
89 using ScalableMaskedAddFOpLowering =
90     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
91                                  ScalableMaskedAddFIntrOp>;
92 using ScalableMaskedSubIOpLowering =
93     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
94                                  ScalableMaskedSubIIntrOp>;
95 using ScalableMaskedSubFOpLowering =
96     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
97                                  ScalableMaskedSubFIntrOp>;
98 using ScalableMaskedMulIOpLowering =
99     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
100                                  ScalableMaskedMulIIntrOp>;
101 using ScalableMaskedMulFOpLowering =
102     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
103                                  ScalableMaskedMulFIntrOp>;
104 using ScalableMaskedSDivIOpLowering =
105     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
106                                  ScalableMaskedSDivIIntrOp>;
107 using ScalableMaskedUDivIOpLowering =
108     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
109                                  ScalableMaskedUDivIIntrOp>;
110 using ScalableMaskedDivFOpLowering =
111     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
112                                  ScalableMaskedDivFIntrOp>;
113 
114 // Load operation is lowered to code that obtains a pointer to the indexed
115 // element and loads from it.
116 struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
117   using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
118 
119   LogicalResult
120   matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
121                   ConversionPatternRewriter &rewriter) const override {
122     auto type = loadOp.getMemRefType();
123     if (!isConvertibleAndHasIdentityMaps(type))
124       return failure();
125 
126     ScalableLoadOp::Adaptor transformed(operands);
127     LLVMTypeConverter converter(loadOp.getContext());
128 
129     auto resultType = loadOp.result().getType();
130     LLVM::LLVMPointerType llvmDataTypePtr;
131     if (resultType.isa<VectorType>()) {
132       llvmDataTypePtr =
133           LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
134     } else if (resultType.isa<ScalableVectorType>()) {
135       llvmDataTypePtr = LLVM::LLVMPointerType::get(
136           convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
137                                           converter)
138               .getValue());
139     }
140     Value dataPtr =
141         getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
142                              transformed.index(), rewriter);
143     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
144         loadOp.getLoc(), llvmDataTypePtr, dataPtr);
145     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
146     return success();
147   }
148 };
149 
150 // Store operation is lowered to code that obtains a pointer to the indexed
151 // element, and stores the given value to it.
152 struct ScalableStoreOpLowering
153     : public ConvertOpToLLVMPattern<ScalableStoreOp> {
154   using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
155 
156   LogicalResult
157   matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
158                   ConversionPatternRewriter &rewriter) const override {
159     auto type = storeOp.getMemRefType();
160     if (!isConvertibleAndHasIdentityMaps(type))
161       return failure();
162 
163     ScalableStoreOp::Adaptor transformed(operands);
164     LLVMTypeConverter converter(storeOp.getContext());
165 
166     auto resultType = storeOp.value().getType();
167     LLVM::LLVMPointerType llvmDataTypePtr;
168     if (resultType.isa<VectorType>()) {
169       llvmDataTypePtr =
170           LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
171     } else if (resultType.isa<ScalableVectorType>()) {
172       llvmDataTypePtr = LLVM::LLVMPointerType::get(
173           convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
174                                           converter)
175               .getValue());
176     }
177     Value dataPtr =
178         getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
179                              transformed.index(), rewriter);
180     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
181         storeOp.getLoc(), llvmDataTypePtr, dataPtr);
182     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
183                                                bitCastedPtr);
184     return success();
185   }
186 };
187 
188 static void
189 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
190                                          OwningRewritePatternList &patterns) {
191   // clang-format off
192   patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
193                OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
194                OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
195                OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
196                OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
197                OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
198                OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
199                OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
200                OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
201               >(converter);
202   // clang-format on
203 }
204 
205 static void
206 configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
207   // clang-format off
208   target.addIllegalOp<ScalableAddIOp,
209                       ScalableAddFOp,
210                       ScalableSubIOp,
211                       ScalableSubFOp,
212                       ScalableMulIOp,
213                       ScalableMulFOp,
214                       ScalableSDivIOp,
215                       ScalableUDivIOp,
216                       ScalableDivFOp>();
217   // clang-format on
218 }
219 
220 static void
221 populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
222                                         OwningRewritePatternList &patterns) {
223   // clang-format off
224   patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
225                OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
226               >(converter);
227   // clang-format on
228 }
229 
230 static void
231 configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
232   // clang-format off
233   target.addIllegalOp<ScalableCmpFOp,
234                       ScalableCmpIOp>();
235   // clang-format on
236 }
237 
238 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
239 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
240     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
241   // Populate conversion patterns
242   // Remove any ArmSVE-specific types from function signatures and results.
243   populateFuncOpTypeConversionPattern(patterns, converter);
244   converter.addConversion([&converter](ScalableVectorType svType) {
245     return convertScalableVectorTypeToLLVM(svType, converter);
246   });
247   converter.addSourceMaterialization(addUnrealizedCast);
248 
249   // clang-format off
250   patterns.add<ForwardOperands<CallOp>,
251                ForwardOperands<CallIndirectOp>,
252                ForwardOperands<ReturnOp>>(converter,
253                                           &converter.getContext());
254   patterns.add<SdotOpLowering,
255                SmmlaOpLowering,
256                UdotOpLowering,
257                UmmlaOpLowering,
258                VectorScaleOpLowering,
259                ScalableMaskedAddIOpLowering,
260                ScalableMaskedAddFOpLowering,
261                ScalableMaskedSubIOpLowering,
262                ScalableMaskedSubFOpLowering,
263                ScalableMaskedMulIOpLowering,
264                ScalableMaskedMulFOpLowering,
265                ScalableMaskedSDivIOpLowering,
266                ScalableMaskedUDivIOpLowering,
267                ScalableMaskedDivFOpLowering>(converter);
268   patterns.add<ScalableLoadOpLowering,
269                ScalableStoreOpLowering>(converter);
270   // clang-format on
271   populateBasicSVEArithmeticExportPatterns(converter, patterns);
272   populateSVEMaskGenerationExportPatterns(converter, patterns);
273 }
274 
275 void mlir::configureArmSVELegalizeForExportTarget(
276     LLVMConversionTarget &target) {
277   // clang-format off
278   target.addLegalOp<SdotIntrOp,
279                     SmmlaIntrOp,
280                     UdotIntrOp,
281                     UmmlaIntrOp,
282                     VectorScaleIntrOp,
283                     ScalableMaskedAddIIntrOp,
284                     ScalableMaskedAddFIntrOp,
285                     ScalableMaskedSubIIntrOp,
286                     ScalableMaskedSubFIntrOp,
287                     ScalableMaskedMulIIntrOp,
288                     ScalableMaskedMulFIntrOp,
289                     ScalableMaskedSDivIIntrOp,
290                     ScalableMaskedUDivIIntrOp,
291                     ScalableMaskedDivFIntrOp>();
292   target.addIllegalOp<SdotOp,
293                       SmmlaOp,
294                       UdotOp,
295                       UmmlaOp,
296                       VectorScaleOp,
297                       ScalableMaskedAddIOp,
298                       ScalableMaskedAddFOp,
299                       ScalableMaskedSubIOp,
300                       ScalableMaskedSubFOp,
301                       ScalableMaskedMulIOp,
302                       ScalableMaskedMulFOp,
303                       ScalableMaskedSDivIOp,
304                       ScalableMaskedUDivIOp,
305                       ScalableMaskedDivFOp,
306                       ScalableLoadOp,
307                       ScalableStoreOp>();
308   // clang-format on
309   auto hasScalableVectorType = [](TypeRange types) {
310     for (Type type : types)
311       if (type.isa<arm_sve::ScalableVectorType>())
312         return true;
313     return false;
314   };
315   target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
316     return !hasScalableVectorType(op.getType().getInputs()) &&
317            !hasScalableVectorType(op.getType().getResults());
318   });
319   target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
320       [hasScalableVectorType](Operation *op) {
321         return !hasScalableVectorType(op->getOperandTypes()) &&
322                !hasScalableVectorType(op->getResultTypes());
323       });
324   configureBasicSVEArithmeticLegalizations(target);
325   configureSVEMaskGenerationLegalizations(target);
326 }
327