//===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "../PassDetail.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" using namespace mlir; namespace { //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// using AddIOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; using XOrIOpLowering = VectorConvertToLLVMPattern; using ShLIOpLowering = VectorConvertToLLVMPattern; using ShRUIOpLowering = VectorConvertToLLVMPattern; using ShRSIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using AddFOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using ExtUIOpLowering = VectorConvertToLLVMPattern; using ExtSIOpLowering = VectorConvertToLLVMPattern; using ExtFOpLowering = VectorConvertToLLVMPattern; using TruncIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; using UIToFPOpLowering = VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; using FPToSIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; //===----------------------------------------------------------------------===// // Op Lowering Patterns //===----------------------------------------------------------------------===// /// Directly lower to LLVM op. struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// The lowering of index_cast becomes an integer conversion since index /// becomes an integer. If the bit width of the source and target integer /// types is the same, just erase the cast. If the target type is wider, /// sign-extend the value, otherwise truncate it. struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // ConstantOpLowering //===----------------------------------------------------------------------===// LogicalResult ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), *getTypeConverter(), rewriter); } //===----------------------------------------------------------------------===// // IndexCastOpLowering //===----------------------------------------------------------------------===// LogicalResult IndexCastOpLowering::matchAndRewrite( arith::IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto targetType = typeConverter->convertType(op.getResult().getType()); auto targetElementType = typeConverter->convertType(getElementTypeOrSelf(op.getResult())) .cast(); auto sourceElementType = getElementTypeOrSelf(adaptor.getIn()).cast(); unsigned targetBits = targetElementType.getWidth(); unsigned sourceBits = sourceElementType.getWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, adaptor.getIn()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); else rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); return success(); } //===----------------------------------------------------------------------===// // CmpIOpLowering //===----------------------------------------------------------------------===// // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums // share numerical values so just cast. template static LLVMPredType convertCmpPredicate(PredType pred) { return static_cast(pred); } LogicalResult CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto operandType = adaptor.getLhs().getType(); auto resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. if (!operandType.isa()) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } //===----------------------------------------------------------------------===// // CmpFOpLowering //===----------------------------------------------------------------------===// LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto operandType = adaptor.getLhs().getType(); auto resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. if (!operandType.isa()) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ConvertArithmeticToLLVMPass : public ConvertArithmeticToLLVMBase { ConvertArithmeticToLLVMPass() = default; void runOnOperation() override { LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(&getContext()); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void mlir::arith::populateArithmeticToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< ConstantOpLowering, AddIOpLowering, SubIOpLowering, MulIOpLowering, DivUIOpLowering, DivSIOpLowering, RemUIOpLowering, RemSIOpLowering, AndIOpLowering, OrIOpLowering, XOrIOpLowering, ShLIOpLowering, ShRUIOpLowering, ShRSIOpLowering, NegFOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering, DivFOpLowering, RemFOpLowering, ExtUIOpLowering, ExtSIOpLowering, ExtFOpLowering, TruncIOpLowering, TruncFOpLowering, UIToFPOpLowering, SIToFPOpLowering, FPToUIOpLowering, FPToSIOpLowering, IndexCastOpLowering, BitcastOpLowering, CmpIOpLowering, CmpFOpLowering, SelectOpLowering >(converter); // clang-format on } std::unique_ptr mlir::arith::createConvertArithmeticToLLVMPass() { return std::make_unique(); }