//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Check that the given type is compatible with the LLVM dialect, emitting an
// error if it is not; the type itself is returned unchanged.
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  if (!LLVM::isCompatibleType(type))
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return type;
}

// Convert an ArmSVE scalable vector type to the corresponding LLVM dialect
// scalable vector type.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  auto elementType = unwrap(converter.convertType(svType.getElementType()));
  if (!elementType)
    return {};

  auto sVectorType =
      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
  return sVectorType;
}

// Replace the operands of an op with their type-converted counterparts,
// leaving the op itself in place.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};

// Update the operands of a ReturnOp to the type-converted values.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};

// Source materialization: wrap a single LLVM scalable vector value back into
// an ArmSVE scalable vector type through an unrealized conversion cast.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1 ||
      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
      .getResult(0);
}

using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableLoadOp::Adaptor transformed(operands);
    LLVMTypeConverter converter(loadOp.getContext());

    auto resultType = loadOp.result().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(
              resultType.cast<ScalableVectorType>(), converter)
              .getValue());
    }
    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};
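// For reference, the rewrite above turns a scalable-vector load from a memref
// into roughly the following LLVM dialect IR. This is a sketch only: the
// address computation comes from the memref descriptor, and the element and
// vector types shown are illustrative assumptions, not fixed by this pattern.
//
//   %ptr  = llvm.getelementptr ...  // strided element address
//   %cast = llvm.bitcast %ptr : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
//   %v    = llvm.load %cast : !llvm.ptr<vec<? x 4 x f32>>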
// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableStoreOp::Adaptor transformed(operands);
    LLVMTypeConverter converter(storeOp.getContext());

    auto resultType = storeOp.value().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(
              resultType.cast<ScalableVectorType>(), converter)
              .getValue());
    }
    Value dataPtr =
        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
                                               bitCastedPtr);
    return success();
  }
};

static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}

static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}

static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}

static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Populate conversion patterns.
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  // clang-format off
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}
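// A minimal sketch of how the two entry points above are typically combined
// in a conversion pass. The pass skeleton below is illustrative only; the
// class and its name are assumptions, not part of this file:
//
//   struct LegalizeArmSVEForLLVMExportPass : ... {
//     void runOnOperation() override {
//       LLVMTypeConverter converter(&getContext());
//       OwningRewritePatternList patterns(&getContext());
//       LLVMConversionTarget target(getContext());
//       populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
//       configureArmSVELegalizeForExportTarget(target);
//       if (failed(applyPartialConversion(getOperation(), target,
//                                         std::move(patterns))))
//         signalPassFailure();
//     }
//   };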