1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
10 #include "mlir/Conversion/LLVMCommon/Pattern.h"
11 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
12 #include "mlir/Dialect/ArmSVE/Transforms.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::arm_sve;
20 
21 // Extract an LLVM IR type from the LLVM IR dialect type.
22 static Type unwrap(Type type) {
23   if (!type)
24     return nullptr;
25   auto *mlirContext = type.getContext();
26   if (!LLVM::isCompatibleType(type))
27     emitError(UnknownLoc::get(mlirContext),
28               "conversion resulted in a non-LLVM type");
29   return type;
30 }
31 
32 static Optional<Type>
33 convertScalableVectorTypeToLLVM(ScalableVectorType svType,
34                                 LLVMTypeConverter &converter) {
35   auto elementType = unwrap(converter.convertType(svType.getElementType()));
36   if (!elementType)
37     return {};
38 
39   auto sVectorType =
40       LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
41   return sVectorType;
42 }
43 
44 template <typename OpTy>
45 class ForwardOperands : public OpConversionPattern<OpTy> {
46   using OpConversionPattern<OpTy>::OpConversionPattern;
47 
48   LogicalResult
49   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
50                   ConversionPatternRewriter &rewriter) const final {
51     if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
52       return rewriter.notifyMatchFailure(op, "operand types already match");
53 
54     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
55     return success();
56   }
57 };
58 
59 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
60 public:
61   using OpConversionPattern<ReturnOp>::OpConversionPattern;
62 
63   LogicalResult
64   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
65                   ConversionPatternRewriter &rewriter) const final {
66     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
67     return success();
68   }
69 };
70 
71 static Optional<Value> addUnrealizedCast(OpBuilder &builder,
72                                          ScalableVectorType svType,
73                                          ValueRange inputs, Location loc) {
74   if (inputs.size() != 1 ||
75       !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
76     return Value();
77   return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
78       .getResult(0);
79 }
80 
81 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
82 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
83 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
84 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
85 using VectorScaleOpLowering =
86     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
87 using ScalableMaskedAddIOpLowering =
88     OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
89                                  ScalableMaskedAddIIntrOp>;
90 using ScalableMaskedAddFOpLowering =
91     OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
92                                  ScalableMaskedAddFIntrOp>;
93 using ScalableMaskedSubIOpLowering =
94     OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
95                                  ScalableMaskedSubIIntrOp>;
96 using ScalableMaskedSubFOpLowering =
97     OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
98                                  ScalableMaskedSubFIntrOp>;
99 using ScalableMaskedMulIOpLowering =
100     OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
101                                  ScalableMaskedMulIIntrOp>;
102 using ScalableMaskedMulFOpLowering =
103     OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
104                                  ScalableMaskedMulFIntrOp>;
105 using ScalableMaskedSDivIOpLowering =
106     OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
107                                  ScalableMaskedSDivIIntrOp>;
108 using ScalableMaskedUDivIOpLowering =
109     OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
110                                  ScalableMaskedUDivIIntrOp>;
111 using ScalableMaskedDivFOpLowering =
112     OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
113                                  ScalableMaskedDivFIntrOp>;
114 
115 // Load operation is lowered to code that obtains a pointer to the indexed
116 // element and loads from it.
117 struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
118   using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
119 
120   LogicalResult
121   matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
122                   ConversionPatternRewriter &rewriter) const override {
123     auto type = loadOp.getMemRefType();
124     if (!isConvertibleAndHasIdentityMaps(type))
125       return failure();
126 
127     ScalableLoadOp::Adaptor transformed(operands);
128     LLVMTypeConverter converter(loadOp.getContext());
129 
130     auto resultType = loadOp.result().getType();
131     LLVM::LLVMPointerType llvmDataTypePtr;
132     if (resultType.isa<VectorType>()) {
133       llvmDataTypePtr =
134           LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
135     } else if (resultType.isa<ScalableVectorType>()) {
136       llvmDataTypePtr = LLVM::LLVMPointerType::get(
137           convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
138                                           converter)
139               .getValue());
140     }
141     Value dataPtr =
142         getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
143                              transformed.index(), rewriter);
144     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
145         loadOp.getLoc(), llvmDataTypePtr, dataPtr);
146     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
147     return success();
148   }
149 };
150 
151 // Store operation is lowered to code that obtains a pointer to the indexed
152 // element, and stores the given value to it.
153 struct ScalableStoreOpLowering
154     : public ConvertOpToLLVMPattern<ScalableStoreOp> {
155   using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
156 
157   LogicalResult
158   matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
159                   ConversionPatternRewriter &rewriter) const override {
160     auto type = storeOp.getMemRefType();
161     if (!isConvertibleAndHasIdentityMaps(type))
162       return failure();
163 
164     ScalableStoreOp::Adaptor transformed(operands);
165     LLVMTypeConverter converter(storeOp.getContext());
166 
167     auto resultType = storeOp.value().getType();
168     LLVM::LLVMPointerType llvmDataTypePtr;
169     if (resultType.isa<VectorType>()) {
170       llvmDataTypePtr =
171           LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
172     } else if (resultType.isa<ScalableVectorType>()) {
173       llvmDataTypePtr = LLVM::LLVMPointerType::get(
174           convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
175                                           converter)
176               .getValue());
177     }
178     Value dataPtr =
179         getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
180                              transformed.index(), rewriter);
181     Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
182         storeOp.getLoc(), llvmDataTypePtr, dataPtr);
183     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
184                                                bitCastedPtr);
185     return success();
186   }
187 };
188 
189 static void
190 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
191                                          OwningRewritePatternList &patterns) {
192   // clang-format off
193   patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
194                OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
195                OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
196                OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
197                OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
198                OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
199                OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
200                OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
201                OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
202               >(converter);
203   // clang-format on
204 }
205 
206 static void
207 configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
208   // clang-format off
209   target.addIllegalOp<ScalableAddIOp,
210                       ScalableAddFOp,
211                       ScalableSubIOp,
212                       ScalableSubFOp,
213                       ScalableMulIOp,
214                       ScalableMulFOp,
215                       ScalableSDivIOp,
216                       ScalableUDivIOp,
217                       ScalableDivFOp>();
218   // clang-format on
219 }
220 
221 static void
222 populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
223                                         OwningRewritePatternList &patterns) {
224   // clang-format off
225   patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
226                OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
227               >(converter);
228   // clang-format on
229 }
230 
231 static void
232 configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
233   // clang-format off
234   target.addIllegalOp<ScalableCmpFOp,
235                       ScalableCmpIOp>();
236   // clang-format on
237 }
238 
239 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
240 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
241     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
242   // Populate conversion patterns
243   // Remove any ArmSVE-specific types from function signatures and results.
244   populateFuncOpTypeConversionPattern(patterns, converter);
245   converter.addConversion([&converter](ScalableVectorType svType) {
246     return convertScalableVectorTypeToLLVM(svType, converter);
247   });
248   converter.addSourceMaterialization(addUnrealizedCast);
249 
250   // clang-format off
251   patterns.add<ForwardOperands<CallOp>,
252                ForwardOperands<CallIndirectOp>,
253                ForwardOperands<ReturnOp>>(converter,
254                                           &converter.getContext());
255   patterns.add<SdotOpLowering,
256                SmmlaOpLowering,
257                UdotOpLowering,
258                UmmlaOpLowering,
259                VectorScaleOpLowering,
260                ScalableMaskedAddIOpLowering,
261                ScalableMaskedAddFOpLowering,
262                ScalableMaskedSubIOpLowering,
263                ScalableMaskedSubFOpLowering,
264                ScalableMaskedMulIOpLowering,
265                ScalableMaskedMulFOpLowering,
266                ScalableMaskedSDivIOpLowering,
267                ScalableMaskedUDivIOpLowering,
268                ScalableMaskedDivFOpLowering>(converter);
269   patterns.add<ScalableLoadOpLowering,
270                ScalableStoreOpLowering>(converter);
271   // clang-format on
272   populateBasicSVEArithmeticExportPatterns(converter, patterns);
273   populateSVEMaskGenerationExportPatterns(converter, patterns);
274 }
275 
276 void mlir::configureArmSVELegalizeForExportTarget(
277     LLVMConversionTarget &target) {
278   // clang-format off
279   target.addLegalOp<SdotIntrOp,
280                     SmmlaIntrOp,
281                     UdotIntrOp,
282                     UmmlaIntrOp,
283                     VectorScaleIntrOp,
284                     ScalableMaskedAddIIntrOp,
285                     ScalableMaskedAddFIntrOp,
286                     ScalableMaskedSubIIntrOp,
287                     ScalableMaskedSubFIntrOp,
288                     ScalableMaskedMulIIntrOp,
289                     ScalableMaskedMulFIntrOp,
290                     ScalableMaskedSDivIIntrOp,
291                     ScalableMaskedUDivIIntrOp,
292                     ScalableMaskedDivFIntrOp>();
293   target.addIllegalOp<SdotOp,
294                       SmmlaOp,
295                       UdotOp,
296                       UmmlaOp,
297                       VectorScaleOp,
298                       ScalableMaskedAddIOp,
299                       ScalableMaskedAddFOp,
300                       ScalableMaskedSubIOp,
301                       ScalableMaskedSubFOp,
302                       ScalableMaskedMulIOp,
303                       ScalableMaskedMulFOp,
304                       ScalableMaskedSDivIOp,
305                       ScalableMaskedUDivIOp,
306                       ScalableMaskedDivFOp,
307                       ScalableLoadOp,
308                       ScalableStoreOp>();
309   // clang-format on
310   auto hasScalableVectorType = [](TypeRange types) {
311     for (Type type : types)
312       if (type.isa<arm_sve::ScalableVectorType>())
313         return true;
314     return false;
315   };
316   target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
317     return !hasScalableVectorType(op.getType().getInputs()) &&
318            !hasScalableVectorType(op.getType().getResults());
319   });
320   target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
321       [hasScalableVectorType](Operation *op) {
322         return !hasScalableVectorType(op->getOperandTypes()) &&
323                !hasScalableVectorType(op->getResultTypes());
324       });
325   configureBasicSVEArithmeticLegalizations(target);
326   configureSVEMaskGenerationLegalizations(target);
327 }
328