1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
10 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
11 #include "mlir/Dialect/ArmSVE/Transforms.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::arm_sve;
19 
20 // Extract an LLVM IR type from the LLVM IR dialect type.
21 static Type unwrap(Type type) {
22   if (!type)
23     return nullptr;
24   auto *mlirContext = type.getContext();
25   if (!LLVM::isCompatibleType(type))
26     emitError(UnknownLoc::get(mlirContext),
27               "conversion resulted in a non-LLVM type");
28   return type;
29 }
30 
31 static Optional<Type>
32 convertScalableVectorTypeToLLVM(ScalableVectorType svType,
33                                 LLVMTypeConverter &converter) {
34   auto elementType = unwrap(converter.convertType(svType.getElementType()));
35   if (!elementType)
36     return {};
37 
38   auto sVectorType =
39       LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
40   return sVectorType;
41 }
42 
43 template <typename OpTy>
44 class ForwardOperands : public OpConversionPattern<OpTy> {
45   using OpConversionPattern<OpTy>::OpConversionPattern;
46 
47   LogicalResult
48   matchAndRewrite(OpTy op, ArrayRef<Value> operands,
49                   ConversionPatternRewriter &rewriter) const final {
50     if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
51       return rewriter.notifyMatchFailure(op, "operand types already match");
52 
53     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
54     return success();
55   }
56 };
57 
58 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
59 public:
60   using OpConversionPattern<ReturnOp>::OpConversionPattern;
61 
62   LogicalResult
63   matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
64                   ConversionPatternRewriter &rewriter) const final {
65     rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
66     return success();
67   }
68 };
69 
70 static Optional<Value> addUnrealizedCast(OpBuilder &builder,
71                                          ScalableVectorType svType,
72                                          ValueRange inputs, Location loc) {
73   if (inputs.size() != 1 ||
74       !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
75     return Value();
76   return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
77       .getResult(0);
78 }
79 
80 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
81 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
82 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
83 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
84 using VectorScaleOpLowering =
85     OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
86 
87 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
88 void mlir::populateArmSVELegalizeForLLVMExportPatterns(
89     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
90   // Populate conversion patterns
91   // Remove any ArmSVE-specific types from function signatures and results.
92   populateFuncOpTypeConversionPattern(patterns, converter);
93   converter.addConversion([&converter](ScalableVectorType svType) {
94     return convertScalableVectorTypeToLLVM(svType, converter);
95   });
96   converter.addSourceMaterialization(addUnrealizedCast);
97 
98   // clang-format off
99   patterns.add<ForwardOperands<CallOp>,
100                ForwardOperands<CallIndirectOp>,
101                ForwardOperands<ReturnOp>>(converter,
102                                           &converter.getContext());
103   patterns.add<SdotOpLowering,
104                SmmlaOpLowering,
105                UdotOpLowering,
106                UmmlaOpLowering,
107                VectorScaleOpLowering>(converter);
108   // clang-format on
109 }
110 
111 void mlir::configureArmSVELegalizeForExportTarget(
112     LLVMConversionTarget &target) {
113   target.addLegalOp<SdotIntrOp>();
114   target.addIllegalOp<SdotOp>();
115   target.addLegalOp<SmmlaIntrOp>();
116   target.addIllegalOp<SmmlaOp>();
117   target.addLegalOp<UdotIntrOp>();
118   target.addIllegalOp<UdotOp>();
119   target.addLegalOp<UmmlaIntrOp>();
120   target.addIllegalOp<UmmlaOp>();
121   target.addLegalOp<VectorScaleIntrOp>();
122   target.addIllegalOp<VectorScaleOp>();
123   auto hasScalableVectorType = [](TypeRange types) {
124     for (Type type : types)
125       if (type.isa<arm_sve::ScalableVectorType>())
126         return true;
127     return false;
128   };
129   target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
130     return !hasScalableVectorType(op.getType().getInputs()) &&
131            !hasScalableVectorType(op.getType().getResults());
132   });
133   target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
134       [hasScalableVectorType](Operation *op) {
135         return !hasScalableVectorType(op->getOperandTypes()) &&
136                !hasScalableVectorType(op->getResultTypes());
137       });
138 }
139