//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Check that the given type is compatible with the LLVM dialect, emitting an
// error otherwise, and return the type unchanged.
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  if (!LLVM::isCompatibleType(type))
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return type;
}

// Convert an ArmSVE scalable vector type to the corresponding LLVM dialect
// scalable vector type, converting the element type with the given type
// converter.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  auto elementType = unwrap(converter.convertType(svType.getElementType()));
  if (!elementType)
    return {};

  auto sVectorType =
      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
  return sVectorType;
}

// Replace an op's operands with their type-converted equivalents, leaving the
// op itself in place.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.updateRootInPlace(
        op, [&]() { op->setOperands(adaptor.getOperands()); });
    return success();
  }
};

// Source materialization: bridge a value of LLVM scalable vector type back to
// the ArmSVE scalable vector type with an unrealized conversion cast.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1 ||
      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
      .getResult(0);
}
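// The aliases below map ArmSVE dialect ops one-to-one onto their LLVM
// intrinsic counterparts via OneToOneConvertToLLVMPattern. Illustratively
// (the exact assembly syntax and types are abbreviated and may differ across
// MLIR versions), SdotOpLowering rewrites
//
//   %0 = arm_sve.sdot %acc, %a, %b : ...
//
// into
//
//   %0 = arm_sve.intr.sdot %acc, %a, %b : ...
//
// with operand and result types converted to LLVM scalable vector types.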
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(loadOp.getContext());

    auto resultType = loadOp.result().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
                                          converter)
              .getValue());
    }
    Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};
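// For reference, the load lowering above emits a sequence roughly of the
// following shape (an illustrative sketch with an f32 element type; the
// exact address computation comes from the memref descriptor via
// getStridedElementPtr, and the printed type syntax may differ across MLIR
// versions):
//
//   %ptr  = llvm.getelementptr %base[%index] : (...) -> !llvm.ptr<f32>
//   %cast = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
//   %v    = llvm.load %cast : !llvm.ptr<vec<? x 4 x f32>>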
// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    LLVMTypeConverter converter(storeOp.getContext());

    auto resultType = storeOp.value().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
                                          converter)
              .getValue());
    }
    Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
                                         adaptor.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
                                               bitCastedPtr);
    return success();
  }
};

static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}

static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}

static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}

static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}
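// The arithmetic and comparison exports above are plain op-to-op rewrites
// that need no intrinsic: the target LLVM dialect ops already accept scalable
// vector operands once the types are converted. Illustratively (mnemonics and
// type syntax abbreviated and possibly version-dependent):
//
//   %0 = arm_sve.addi %a, %b : !arm_sve.vector<4xi32>
//     ==>
//   %0 = llvm.add %a, %b : !llvm.vec<? x 4 x i32>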
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Populate conversion patterns.
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  // clang-format off
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}
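// A conversion pass would typically drive the two entry points above roughly
// as follows (an illustrative sketch, not a definitive driver; in-tree, the
// VectorToLLVM conversion pass invokes these hooks when its ArmSVE option is
// enabled):
//
//   LLVMConversionTarget target(getContext());
//   LLVMTypeConverter converter(&getContext());
//   OwningRewritePatternList patterns(&getContext());
//   populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
//   configureArmSVELegalizeForExportTarget(target);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();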