1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 10 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 11 #include "mlir/Dialect/ArmSVE/Transforms.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/PatternMatch.h" 16 17 using namespace mlir; 18 using namespace mlir::arm_sve; 19 20 // Extract an LLVM IR type from the LLVM IR dialect type. 21 static Type unwrap(Type type) { 22 if (!type) 23 return nullptr; 24 auto *mlirContext = type.getContext(); 25 if (!LLVM::isCompatibleType(type)) 26 emitError(UnknownLoc::get(mlirContext), 27 "conversion resulted in a non-LLVM type"); 28 return type; 29 } 30 31 static Optional<Type> 32 convertScalableVectorTypeToLLVM(ScalableVectorType svType, 33 LLVMTypeConverter &converter) { 34 auto elementType = unwrap(converter.convertType(svType.getElementType())); 35 if (!elementType) 36 return {}; 37 38 auto sVectorType = 39 LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); 40 return sVectorType; 41 } 42 43 template <typename OpTy> 44 class ForwardOperands : public OpConversionPattern<OpTy> { 45 using OpConversionPattern<OpTy>::OpConversionPattern; 46 47 LogicalResult 48 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const final { 50 if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) 51 return rewriter.notifyMatchFailure(op, "operand types already match"); 52 53 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); 54 return success(); 55 } 56 }; 57 58 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 59 public: 60 using OpConversionPattern<ReturnOp>::OpConversionPattern; 61 62 LogicalResult 63 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 64 ConversionPatternRewriter &rewriter) const final { 65 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); 66 return success(); 67 } 68 }; 69 70 static Optional<Value> addUnrealizedCast(OpBuilder &builder, 71 ScalableVectorType svType, 72 ValueRange inputs, Location loc) { 73 if (inputs.size() != 1 || 74 !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>()) 75 return Value(); 76 return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs) 77 .getResult(0); 78 } 79 80 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 81 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 82 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 83 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 84 using VectorScaleOpLowering = 85 OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>; 86 using ScalableMaskedAddIOpLowering = 87 OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp, 88 ScalableMaskedAddIIntrOp>; 89 using ScalableMaskedAddFOpLowering = 90 OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp, 91 ScalableMaskedAddFIntrOp>; 92 using ScalableMaskedSubIOpLowering = 93 OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp, 94 ScalableMaskedSubIIntrOp>; 95 using ScalableMaskedSubFOpLowering = 96 OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp, 97 ScalableMaskedSubFIntrOp>; 98 using ScalableMaskedMulIOpLowering = 99 OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp, 100 ScalableMaskedMulIIntrOp>; 101 using ScalableMaskedMulFOpLowering = 102 OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp, 103 ScalableMaskedMulFIntrOp>; 104 using ScalableMaskedSDivIOpLowering = 105 OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp, 106 ScalableMaskedSDivIIntrOp>; 107 using ScalableMaskedUDivIOpLowering = 108 OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp, 109 ScalableMaskedUDivIIntrOp>; 110 using ScalableMaskedDivFOpLowering = 111 OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp, 112 ScalableMaskedDivFIntrOp>; 113 114 static void 115 populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, 116 OwningRewritePatternList &patterns) { 117 // clang-format off 118 patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>, 119 OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>, 120 OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>, 121 OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>, 122 OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>, 123 OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>, 124 OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>, 125 OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>, 126 OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp> 127 >(converter); 128 // clang-format on 129 } 130 131 static void 132 configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) { 133 // clang-format off 134 target.addIllegalOp<ScalableAddIOp, 135 ScalableAddFOp, 136 ScalableSubIOp, 137 ScalableSubFOp, 138 ScalableMulIOp, 139 ScalableMulFOp, 140 ScalableSDivIOp, 141 ScalableUDivIOp, 142 ScalableDivFOp>(); 143 // clang-format on 144 } 145 146 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 147 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 148 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 149 // Populate conversion patterns 150 // Remove any ArmSVE-specific types from function signatures and results. 151 populateFuncOpTypeConversionPattern(patterns, converter); 152 converter.addConversion([&converter](ScalableVectorType svType) { 153 return convertScalableVectorTypeToLLVM(svType, converter); 154 }); 155 converter.addSourceMaterialization(addUnrealizedCast); 156 157 // clang-format off 158 patterns.add<ForwardOperands<CallOp>, 159 ForwardOperands<CallIndirectOp>, 160 ForwardOperands<ReturnOp>>(converter, 161 &converter.getContext()); 162 patterns.add<SdotOpLowering, 163 SmmlaOpLowering, 164 UdotOpLowering, 165 UmmlaOpLowering, 166 VectorScaleOpLowering, 167 ScalableMaskedAddIOpLowering, 168 ScalableMaskedAddFOpLowering, 169 ScalableMaskedSubIOpLowering, 170 ScalableMaskedSubFOpLowering, 171 ScalableMaskedMulIOpLowering, 172 ScalableMaskedMulFOpLowering, 173 ScalableMaskedSDivIOpLowering, 174 ScalableMaskedUDivIOpLowering, 175 ScalableMaskedDivFOpLowering>(converter); 176 // clang-format on 177 populateBasicSVEArithmeticExportPatterns(converter, patterns); 178 } 179 180 void mlir::configureArmSVELegalizeForExportTarget( 181 LLVMConversionTarget &target) { 182 // clang-format off 183 target.addLegalOp<SdotIntrOp, 184 SmmlaIntrOp, 185 UdotIntrOp, 186 UmmlaIntrOp, 187 VectorScaleIntrOp, 188 ScalableMaskedAddIIntrOp, 189 ScalableMaskedAddFIntrOp, 190 ScalableMaskedSubIIntrOp, 191 ScalableMaskedSubFIntrOp, 192 ScalableMaskedMulIIntrOp, 193 ScalableMaskedMulFIntrOp, 194 ScalableMaskedSDivIIntrOp, 195 ScalableMaskedUDivIIntrOp, 196 ScalableMaskedDivFIntrOp>(); 197 target.addIllegalOp<SdotOp, 198 SmmlaOp, 199 UdotOp, 200 UmmlaOp, 201 VectorScaleOp, 202 ScalableMaskedAddIOp, 203 ScalableMaskedAddFOp, 204 ScalableMaskedSubIOp, 205 ScalableMaskedSubFOp, 206 ScalableMaskedMulIOp, 207 ScalableMaskedMulFOp, 208 ScalableMaskedSDivIOp, 209 ScalableMaskedUDivIOp, 210 ScalableMaskedDivFOp>(); 211 // clang-format on 212 auto hasScalableVectorType = [](TypeRange types) { 213 for (Type type : types) 214 if (type.isa<arm_sve::ScalableVectorType>()) 215 return true; 216 return false; 217 }; 218 target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) { 219 return !hasScalableVectorType(op.getType().getInputs()) && 220 !hasScalableVectorType(op.getType().getResults()); 221 }); 222 target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>( 223 [hasScalableVectorType](Operation *op) { 224 return !hasScalableVectorType(op->getOperandTypes()) && 225 !hasScalableVectorType(op->getResultTypes()); 226 }); 227 configureBasicSVEArithmeticLegalizations(target); 228 } 229