1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 10 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" 11 #include "mlir/Dialect/ArmSVE/Transforms.h" 12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 13 #include "mlir/Dialect/StandardOps/IR/Ops.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/PatternMatch.h" 16 17 using namespace mlir; 18 using namespace mlir::arm_sve; 19 20 // Extract an LLVM IR type from the LLVM IR dialect type. 21 static Type unwrap(Type type) { 22 if (!type) 23 return nullptr; 24 auto *mlirContext = type.getContext(); 25 if (!LLVM::isCompatibleType(type)) 26 emitError(UnknownLoc::get(mlirContext), 27 "conversion resulted in a non-LLVM type"); 28 return type; 29 } 30 31 static Optional<Type> 32 convertScalableVectorTypeToLLVM(ScalableVectorType svType, 33 LLVMTypeConverter &converter) { 34 auto elementType = unwrap(converter.convertType(svType.getElementType())); 35 if (!elementType) 36 return {}; 37 38 auto sVectorType = 39 LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); 40 return sVectorType; 41 } 42 43 template <typename OpTy> 44 class ForwardOperands : public OpConversionPattern<OpTy> { 45 using OpConversionPattern<OpTy>::OpConversionPattern; 46 47 LogicalResult 48 matchAndRewrite(OpTy op, ArrayRef<Value> operands, 49 ConversionPatternRewriter &rewriter) const final { 50 if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) 51 return rewriter.notifyMatchFailure(op, "operand types already match"); 52 53 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); 54 return success(); 55 } 56 }; 57 58 class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> { 59 public: 60 using OpConversionPattern<ReturnOp>::OpConversionPattern; 61 62 LogicalResult 63 matchAndRewrite(ReturnOp op, ArrayRef<Value> operands, 64 ConversionPatternRewriter &rewriter) const final { 65 rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); 66 return success(); 67 } 68 }; 69 70 static Optional<Value> addUnrealizedCast(OpBuilder &builder, 71 ScalableVectorType svType, 72 ValueRange inputs, Location loc) { 73 if (inputs.size() != 1 || 74 !inputs[0].getType().isa<LLVM::LLVMScalableVectorType>()) 75 return Value(); 76 return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs) 77 .getResult(0); 78 } 79 80 using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>; 81 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>; 82 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>; 83 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>; 84 using VectorScaleOpLowering = 85 OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>; 86 87 /// Populate the given list with patterns that convert from ArmSVE to LLVM. 88 void mlir::populateArmSVELegalizeForLLVMExportPatterns( 89 LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { 90 // Populate conversion patterns 91 // Remove any ArmSVE-specific types from function signatures and results. 92 populateFuncOpTypeConversionPattern(patterns, converter); 93 converter.addConversion([&converter](ScalableVectorType svType) { 94 return convertScalableVectorTypeToLLVM(svType, converter); 95 }); 96 converter.addSourceMaterialization(addUnrealizedCast); 97 98 // clang-format off 99 patterns.add<ForwardOperands<CallOp>, 100 ForwardOperands<CallIndirectOp>, 101 ForwardOperands<ReturnOp>>(converter, 102 &converter.getContext()); 103 patterns.add<SdotOpLowering, 104 SmmlaOpLowering, 105 UdotOpLowering, 106 UmmlaOpLowering, 107 VectorScaleOpLowering>(converter); 108 // clang-format on 109 } 110 111 void mlir::configureArmSVELegalizeForExportTarget( 112 LLVMConversionTarget &target) { 113 target.addLegalOp<SdotIntrOp>(); 114 target.addIllegalOp<SdotOp>(); 115 target.addLegalOp<SmmlaIntrOp>(); 116 target.addIllegalOp<SmmlaOp>(); 117 target.addLegalOp<UdotIntrOp>(); 118 target.addIllegalOp<UdotOp>(); 119 target.addLegalOp<UmmlaIntrOp>(); 120 target.addIllegalOp<UmmlaOp>(); 121 target.addLegalOp<VectorScaleIntrOp>(); 122 target.addIllegalOp<VectorScaleOp>(); 123 auto hasScalableVectorType = [](TypeRange types) { 124 for (Type type : types) 125 if (type.isa<arm_sve::ScalableVectorType>()) 126 return true; 127 return false; 128 }; 129 target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) { 130 return !hasScalableVectorType(op.getType().getInputs()) && 131 !hasScalableVectorType(op.getType().getResults()); 132 }); 133 target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>( 134 [hasScalableVectorType](Operation *op) { 135 return !hasScalableVectorType(op->getOperandTypes()) && 136 !hasScalableVectorType(op->getResultTypes()); 137 }); 138 } 139