//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Extract an LLVM IR type from the LLVM IR dialect type.
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  if (!LLVM::isCompatibleType(type))
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return type;
}

static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  auto elementType = unwrap(converter.convertType(svType.getElementType()));
  if (!elementType)
    return {};

  auto sVectorType =
      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
  return sVectorType;
}

template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
      return rewriter.notifyMatchFailure(op, "operand types already match");

    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};

class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
    return success();
  }
};

static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  if (inputs.size() != 1 ||
      !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
      .getResult(0);
}

using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;
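
// Each alias above instantiates the generic one-to-one pattern: the ArmSVE op
// is rewritten to its intrinsic counterpart (on the LLVM-converted operand
// types), which the LLVM IR translation then maps to the corresponding
// llvm.aarch64.sve.* intrinsic. As a rough sketch (assembly syntax
// approximated for this MLIR vintage):
//
//   %0 = arm_sve.sdot %acc, %a, %b :
//          !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
//
// becomes an arm_sve.intr.sdot op, and ultimately a call to the
// @llvm.aarch64.sve.sdot intrinsic.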

// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableLoadOp::Adaptor transformed(operands);
    LLVMTypeConverter converter(loadOp.getContext());

    auto resultType = loadOp.result().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(
              resultType.cast<ScalableVectorType>(), converter)
              .getValue());
    }
    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};
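
// Sketch of the IR produced for a scalable load from a memref<?xf32> (types
// abbreviated; exact syntax depends on the MLIR vintage). The strided element
// address is computed, bitcast to a pointer-to-scalable-vector, and then
// loaded with a plain llvm.load:
//
//   %p = ... strided element pointer ... : !llvm.ptr<f32>
//   %q = llvm.bitcast %p : !llvm.ptr<f32> to !llvm.ptr<vec<? x 4 x f32>>
//   %v = llvm.load %q : !llvm.ptr<vec<? x 4 x f32>>
//
// The store lowering below is symmetric, ending in an llvm.store.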

// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableStoreOp::Adaptor transformed(operands);
    LLVMTypeConverter converter(storeOp.getContext());

    auto resultType = storeOp.value().getType();
    LLVM::LLVMPointerType llvmDataTypePtr;
    if (resultType.isa<VectorType>()) {
      llvmDataTypePtr =
          LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
    } else if (resultType.isa<ScalableVectorType>()) {
      llvmDataTypePtr = LLVM::LLVMPointerType::get(
          convertScalableVectorTypeToLLVM(
              resultType.cast<ScalableVectorType>(), converter)
              .getValue());
    }
    Value dataPtr =
        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
                                               bitCastedPtr);
    return success();
  }
};

static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
               OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
               OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
              >(converter);
  // clang-format on
}

static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableAddIOp,
                      ScalableAddFOp,
                      ScalableSubIOp,
                      ScalableSubFOp,
                      ScalableMulIOp,
                      ScalableMulFOp,
                      ScalableSDivIOp,
                      ScalableUDivIOp,
                      ScalableDivFOp>();
  // clang-format on
}

static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  // clang-format off
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
               OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
              >(converter);
  // clang-format on
}

static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  // clang-format off
  target.addIllegalOp<ScalableCmpFOp,
                      ScalableCmpIOp>();
  // clang-format on
}
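
// Note that the basic arithmetic and comparison ops above need no
// SVE-specific intrinsics: the generic LLVM dialect operations (add, fadd,
// icmp, fcmp, etc.) accept scalable vector operands directly, so a plain
// one-to-one rewrite to the LLVM dialect suffices.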

/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Populate conversion patterns.
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);
  converter.addConversion([&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  });
  converter.addSourceMaterialization(addUnrealizedCast);

  // clang-format off
  patterns.add<ForwardOperands<CallOp>,
               ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter,
                                          &converter.getContext());
  patterns.add<SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               VectorScaleOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);
  patterns.add<ScalableLoadOpLowering,
               ScalableStoreOpLowering>(converter);
  // clang-format on
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    VectorScaleIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      VectorScaleOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ScalableLoadOp,
                      ScalableStoreOp>();
  // clang-format on
  auto hasScalableVectorType = [](TypeRange types) {
    for (Type type : types)
      if (type.isa<arm_sve::ScalableVectorType>())
        return true;
    return false;
  };
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    return !hasScalableVectorType(op.getType().getInputs()) &&
           !hasScalableVectorType(op.getType().getResults());
  });
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        return !hasScalableVectorType(op->getOperandTypes()) &&
               !hasScalableVectorType(op->getResultTypes());
      });
  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}
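
// A minimal sketch of how the two entry points above are typically wired into
// a conversion pass (hypothetical pass body, not part of this file; modeled
// on the vector-to-LLVM lowering of the same vintage):
//
//   void LowerArmSVEToLLVMPass::runOnOperation() {
//     LLVMTypeConverter converter(&getContext());
//     OwningRewritePatternList patterns(&getContext());
//     LLVMConversionTarget target(getContext());
//     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
//     configureArmSVELegalizeForExportTarget(target);
//     if (failed(applyPartialConversion(getOperation(), target,
//                                       std::move(patterns))))
//       signalPassFailure();
//   }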