1*a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2*a54f4eaeSMogball // 3*a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information. 5*a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*a54f4eaeSMogball // 7*a54f4eaeSMogball //===----------------------------------------------------------------------===// 8*a54f4eaeSMogball 9*a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10*a54f4eaeSMogball #include "../PassDetail.h" 11*a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12*a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14*a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15*a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h" 16*a54f4eaeSMogball 17*a54f4eaeSMogball using namespace mlir; 18*a54f4eaeSMogball 19*a54f4eaeSMogball namespace { 20*a54f4eaeSMogball 21*a54f4eaeSMogball //===----------------------------------------------------------------------===// 22*a54f4eaeSMogball // Straightforward Op Lowerings 23*a54f4eaeSMogball //===----------------------------------------------------------------------===// 24*a54f4eaeSMogball 25*a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>; 26*a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>; 27*a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>; 28*a54f4eaeSMogball using DivUIOpLowering = 29*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 30*a54f4eaeSMogball using DivSIOpLowering = 31*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 32*a54f4eaeSMogball using RemUIOpLowering = 33*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 34*a54f4eaeSMogball using RemSIOpLowering = 35*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 36*a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 37*a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 38*a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 39*a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>; 40*a54f4eaeSMogball using ShRUIOpLowering = 41*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 42*a54f4eaeSMogball using ShRSIOpLowering = 43*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 44*a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>; 45*a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>; 46*a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>; 47*a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>; 48*a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>; 49*a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>; 50*a54f4eaeSMogball using ExtUIOpLowering = 51*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 52*a54f4eaeSMogball using ExtSIOpLowering = 53*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 54*a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 55*a54f4eaeSMogball using TruncIOpLowering = 56*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 57*a54f4eaeSMogball using TruncFOpLowering = 58*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>; 59*a54f4eaeSMogball using UIToFPOpLowering = 60*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 61*a54f4eaeSMogball using SIToFPOpLowering = 62*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 63*a54f4eaeSMogball using FPToUIOpLowering = 64*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 65*a54f4eaeSMogball using FPToSIOpLowering = 66*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 67*a54f4eaeSMogball using BitcastOpLowering = 68*a54f4eaeSMogball VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 69*a54f4eaeSMogball 70*a54f4eaeSMogball //===----------------------------------------------------------------------===// 71*a54f4eaeSMogball // Op Lowering Patterns 72*a54f4eaeSMogball //===----------------------------------------------------------------------===// 73*a54f4eaeSMogball 74*a54f4eaeSMogball /// Directly lower to LLVM op. 75*a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 76*a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern; 77*a54f4eaeSMogball 78*a54f4eaeSMogball LogicalResult 79*a54f4eaeSMogball matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 80*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 81*a54f4eaeSMogball }; 82*a54f4eaeSMogball 83*a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index 84*a54f4eaeSMogball /// becomes an integer. If the bit width of the source and target integer 85*a54f4eaeSMogball /// types is the same, just erase the cast. If the target type is wider, 86*a54f4eaeSMogball /// sign-extend the value, otherwise truncate it. 87*a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> { 88*a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern; 89*a54f4eaeSMogball 90*a54f4eaeSMogball LogicalResult 91*a54f4eaeSMogball matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 92*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 93*a54f4eaeSMogball }; 94*a54f4eaeSMogball 95*a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 96*a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern; 97*a54f4eaeSMogball 98*a54f4eaeSMogball LogicalResult 99*a54f4eaeSMogball matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 100*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 101*a54f4eaeSMogball }; 102*a54f4eaeSMogball 103*a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 104*a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern; 105*a54f4eaeSMogball 106*a54f4eaeSMogball LogicalResult 107*a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 108*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 109*a54f4eaeSMogball }; 110*a54f4eaeSMogball 111*a54f4eaeSMogball } // end anonymous namespace 112*a54f4eaeSMogball 113*a54f4eaeSMogball //===----------------------------------------------------------------------===// 114*a54f4eaeSMogball // ConstantOpLowering 115*a54f4eaeSMogball //===----------------------------------------------------------------------===// 116*a54f4eaeSMogball 117*a54f4eaeSMogball LogicalResult 118*a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 119*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 120*a54f4eaeSMogball return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 121*a54f4eaeSMogball adaptor.getOperands(), 122*a54f4eaeSMogball *getTypeConverter(), rewriter); 123*a54f4eaeSMogball } 124*a54f4eaeSMogball 125*a54f4eaeSMogball //===----------------------------------------------------------------------===// 126*a54f4eaeSMogball // IndexCastOpLowering 127*a54f4eaeSMogball //===----------------------------------------------------------------------===// 128*a54f4eaeSMogball 129*a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite( 130*a54f4eaeSMogball arith::IndexCastOp op, OpAdaptor adaptor, 131*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 132*a54f4eaeSMogball auto targetType = typeConverter->convertType(op.getResult().getType()); 133*a54f4eaeSMogball auto targetElementType = 134*a54f4eaeSMogball typeConverter->convertType(getElementTypeOrSelf(op.getResult())) 135*a54f4eaeSMogball .cast<IntegerType>(); 136*a54f4eaeSMogball auto sourceElementType = 137*a54f4eaeSMogball getElementTypeOrSelf(adaptor.in()).cast<IntegerType>(); 138*a54f4eaeSMogball unsigned targetBits = targetElementType.getWidth(); 139*a54f4eaeSMogball unsigned sourceBits = sourceElementType.getWidth(); 140*a54f4eaeSMogball 141*a54f4eaeSMogball if (targetBits == sourceBits) 142*a54f4eaeSMogball rewriter.replaceOp(op, adaptor.in()); 143*a54f4eaeSMogball else if (targetBits < sourceBits) 144*a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.in()); 145*a54f4eaeSMogball else 146*a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.in()); 147*a54f4eaeSMogball return success(); 148*a54f4eaeSMogball } 149*a54f4eaeSMogball 150*a54f4eaeSMogball //===----------------------------------------------------------------------===// 151*a54f4eaeSMogball // CmpIOpLowering 152*a54f4eaeSMogball //===----------------------------------------------------------------------===// 153*a54f4eaeSMogball 154*a54f4eaeSMogball // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 155*a54f4eaeSMogball // share numerical values so just cast. 156*a54f4eaeSMogball template <typename LLVMPredType, typename PredType> 157*a54f4eaeSMogball static LLVMPredType convertCmpPredicate(PredType pred) { 158*a54f4eaeSMogball return static_cast<LLVMPredType>(pred); 159*a54f4eaeSMogball } 160*a54f4eaeSMogball 161*a54f4eaeSMogball LogicalResult 162*a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 163*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 164*a54f4eaeSMogball auto operandType = adaptor.lhs().getType(); 165*a54f4eaeSMogball auto resultType = op.getResult().getType(); 166*a54f4eaeSMogball 167*a54f4eaeSMogball // Handle the scalar and 1D vector cases. 168*a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) { 169*a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 170*a54f4eaeSMogball op, typeConverter->convertType(resultType), 171*a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 172*a54f4eaeSMogball adaptor.lhs(), adaptor.rhs()); 173*a54f4eaeSMogball return success(); 174*a54f4eaeSMogball } 175*a54f4eaeSMogball 176*a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>(); 177*a54f4eaeSMogball if (!vectorType) 178*a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type"); 179*a54f4eaeSMogball 180*a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors( 181*a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 182*a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) { 183*a54f4eaeSMogball OpAdaptor adaptor(operands); 184*a54f4eaeSMogball return rewriter.create<LLVM::ICmpOp>( 185*a54f4eaeSMogball op.getLoc(), llvm1DVectorTy, 186*a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 187*a54f4eaeSMogball adaptor.lhs(), adaptor.rhs()); 188*a54f4eaeSMogball }, 189*a54f4eaeSMogball rewriter); 190*a54f4eaeSMogball 191*a54f4eaeSMogball return success(); 192*a54f4eaeSMogball } 193*a54f4eaeSMogball 194*a54f4eaeSMogball //===----------------------------------------------------------------------===// 195*a54f4eaeSMogball // CmpFOpLowering 196*a54f4eaeSMogball //===----------------------------------------------------------------------===// 197*a54f4eaeSMogball 198*a54f4eaeSMogball LogicalResult 199*a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 200*a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 201*a54f4eaeSMogball auto operandType = adaptor.lhs().getType(); 202*a54f4eaeSMogball auto resultType = op.getResult().getType(); 203*a54f4eaeSMogball 204*a54f4eaeSMogball // Handle the scalar and 1D vector cases. 205*a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) { 206*a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 207*a54f4eaeSMogball op, typeConverter->convertType(resultType), 208*a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 209*a54f4eaeSMogball adaptor.lhs(), adaptor.rhs()); 210*a54f4eaeSMogball return success(); 211*a54f4eaeSMogball } 212*a54f4eaeSMogball 213*a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>(); 214*a54f4eaeSMogball if (!vectorType) 215*a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type"); 216*a54f4eaeSMogball 217*a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors( 218*a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 219*a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) { 220*a54f4eaeSMogball OpAdaptor adaptor(operands); 221*a54f4eaeSMogball return rewriter.create<LLVM::FCmpOp>( 222*a54f4eaeSMogball op.getLoc(), llvm1DVectorTy, 223*a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 224*a54f4eaeSMogball adaptor.lhs(), adaptor.rhs()); 225*a54f4eaeSMogball }, 226*a54f4eaeSMogball rewriter); 227*a54f4eaeSMogball } 228*a54f4eaeSMogball 229*a54f4eaeSMogball //===----------------------------------------------------------------------===// 230*a54f4eaeSMogball // Pass Definition 231*a54f4eaeSMogball //===----------------------------------------------------------------------===// 232*a54f4eaeSMogball 233*a54f4eaeSMogball namespace { 234*a54f4eaeSMogball struct ConvertArithmeticToLLVMPass 235*a54f4eaeSMogball : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { 236*a54f4eaeSMogball ConvertArithmeticToLLVMPass() = default; 237*a54f4eaeSMogball 238*a54f4eaeSMogball void runOnFunction() override { 239*a54f4eaeSMogball LLVMConversionTarget target(getContext()); 240*a54f4eaeSMogball RewritePatternSet patterns(&getContext()); 241*a54f4eaeSMogball 242*a54f4eaeSMogball LowerToLLVMOptions options(&getContext()); 243*a54f4eaeSMogball if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 244*a54f4eaeSMogball options.overrideIndexBitwidth(indexBitwidth); 245*a54f4eaeSMogball 246*a54f4eaeSMogball LLVMTypeConverter converter(&getContext(), options); 247*a54f4eaeSMogball mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 248*a54f4eaeSMogball patterns); 249*a54f4eaeSMogball 250*a54f4eaeSMogball if (failed( 251*a54f4eaeSMogball applyPartialConversion(getFunction(), target, std::move(patterns)))) 252*a54f4eaeSMogball signalPassFailure(); 253*a54f4eaeSMogball } 254*a54f4eaeSMogball }; 255*a54f4eaeSMogball } // end anonymous namespace 256*a54f4eaeSMogball 257*a54f4eaeSMogball //===----------------------------------------------------------------------===// 258*a54f4eaeSMogball // Pattern Population 259*a54f4eaeSMogball //===----------------------------------------------------------------------===// 260*a54f4eaeSMogball 261*a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns( 262*a54f4eaeSMogball LLVMTypeConverter &converter, RewritePatternSet &patterns) { 263*a54f4eaeSMogball // clang-format off 264*a54f4eaeSMogball patterns.add< 265*a54f4eaeSMogball ConstantOpLowering, 266*a54f4eaeSMogball AddIOpLowering, 267*a54f4eaeSMogball SubIOpLowering, 268*a54f4eaeSMogball MulIOpLowering, 269*a54f4eaeSMogball DivUIOpLowering, 270*a54f4eaeSMogball DivSIOpLowering, 271*a54f4eaeSMogball RemUIOpLowering, 272*a54f4eaeSMogball RemSIOpLowering, 273*a54f4eaeSMogball AndIOpLowering, 274*a54f4eaeSMogball OrIOpLowering, 275*a54f4eaeSMogball XOrIOpLowering, 276*a54f4eaeSMogball ShLIOpLowering, 277*a54f4eaeSMogball ShRUIOpLowering, 278*a54f4eaeSMogball ShRSIOpLowering, 279*a54f4eaeSMogball NegFOpLowering, 280*a54f4eaeSMogball AddFOpLowering, 281*a54f4eaeSMogball SubFOpLowering, 282*a54f4eaeSMogball MulFOpLowering, 283*a54f4eaeSMogball DivFOpLowering, 284*a54f4eaeSMogball RemFOpLowering, 285*a54f4eaeSMogball ExtUIOpLowering, 286*a54f4eaeSMogball ExtSIOpLowering, 287*a54f4eaeSMogball ExtFOpLowering, 288*a54f4eaeSMogball TruncIOpLowering, 289*a54f4eaeSMogball TruncFOpLowering, 290*a54f4eaeSMogball UIToFPOpLowering, 291*a54f4eaeSMogball SIToFPOpLowering, 292*a54f4eaeSMogball FPToUIOpLowering, 293*a54f4eaeSMogball FPToSIOpLowering, 294*a54f4eaeSMogball IndexCastOpLowering, 295*a54f4eaeSMogball BitcastOpLowering, 296*a54f4eaeSMogball CmpIOpLowering, 297*a54f4eaeSMogball CmpFOpLowering 298*a54f4eaeSMogball >(converter); 299*a54f4eaeSMogball // clang-format on 300*a54f4eaeSMogball } 301*a54f4eaeSMogball 302*a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() { 303*a54f4eaeSMogball return std::make_unique<ConvertArithmeticToLLVMPass>(); 304*a54f4eaeSMogball } 305