1a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2a54f4eaeSMogball // 3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information. 5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a54f4eaeSMogball // 7a54f4eaeSMogball //===----------------------------------------------------------------------===// 8a54f4eaeSMogball 9a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10a54f4eaeSMogball #include "../PassDetail.h" 11a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h" 16a54f4eaeSMogball 17a54f4eaeSMogball using namespace mlir; 18a54f4eaeSMogball 19a54f4eaeSMogball namespace { 20a54f4eaeSMogball 21a54f4eaeSMogball //===----------------------------------------------------------------------===// 22a54f4eaeSMogball // Straightforward Op Lowerings 23a54f4eaeSMogball //===----------------------------------------------------------------------===// 24a54f4eaeSMogball 25a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>; 26a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>; 27a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>; 28a54f4eaeSMogball using DivUIOpLowering = 29a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 30a54f4eaeSMogball using DivSIOpLowering = 31a54f4eaeSMogball VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 32a54f4eaeSMogball using RemUIOpLowering = 33a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 34a54f4eaeSMogball using RemSIOpLowering = 35a54f4eaeSMogball VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 36a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 37a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 38a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 39a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>; 40a54f4eaeSMogball using ShRUIOpLowering = 41a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 42a54f4eaeSMogball using ShRSIOpLowering = 43a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 44a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>; 45a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>; 46a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>; 47a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>; 48a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>; 49a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>; 50a54f4eaeSMogball using ExtUIOpLowering = 51a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 52a54f4eaeSMogball using ExtSIOpLowering = 53a54f4eaeSMogball VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 54a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 55a54f4eaeSMogball using TruncIOpLowering = 56a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 57a54f4eaeSMogball using TruncFOpLowering = 58a54f4eaeSMogball VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>; 59a54f4eaeSMogball using UIToFPOpLowering = 60a54f4eaeSMogball VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 61a54f4eaeSMogball using SIToFPOpLowering = 62a54f4eaeSMogball VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 63a54f4eaeSMogball using FPToUIOpLowering = 64a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 65a54f4eaeSMogball using FPToSIOpLowering = 66a54f4eaeSMogball VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 67a54f4eaeSMogball using BitcastOpLowering = 68a54f4eaeSMogball VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 69a54f4eaeSMogball 70a54f4eaeSMogball //===----------------------------------------------------------------------===// 71a54f4eaeSMogball // Op Lowering Patterns 72a54f4eaeSMogball //===----------------------------------------------------------------------===// 73a54f4eaeSMogball 74a54f4eaeSMogball /// Directly lower to LLVM op. 75a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 76a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern; 77a54f4eaeSMogball 78a54f4eaeSMogball LogicalResult 79a54f4eaeSMogball matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 80a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 81a54f4eaeSMogball }; 82a54f4eaeSMogball 83a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index 84a54f4eaeSMogball /// becomes an integer. If the bit width of the source and target integer 85a54f4eaeSMogball /// types is the same, just erase the cast. If the target type is wider, 86a54f4eaeSMogball /// sign-extend the value, otherwise truncate it. 87a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> { 88a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern; 89a54f4eaeSMogball 90a54f4eaeSMogball LogicalResult 91a54f4eaeSMogball matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 92a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 93a54f4eaeSMogball }; 94a54f4eaeSMogball 95a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 96a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern; 97a54f4eaeSMogball 98a54f4eaeSMogball LogicalResult 99a54f4eaeSMogball matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 100a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 101a54f4eaeSMogball }; 102a54f4eaeSMogball 103a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 104a54f4eaeSMogball using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern; 105a54f4eaeSMogball 106a54f4eaeSMogball LogicalResult 107a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 108a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override; 109a54f4eaeSMogball }; 110a54f4eaeSMogball 111a54f4eaeSMogball } // end anonymous namespace 112a54f4eaeSMogball 113a54f4eaeSMogball //===----------------------------------------------------------------------===// 114a54f4eaeSMogball // ConstantOpLowering 115a54f4eaeSMogball //===----------------------------------------------------------------------===// 116a54f4eaeSMogball 117a54f4eaeSMogball LogicalResult 118a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 119a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 120a54f4eaeSMogball return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 121a54f4eaeSMogball adaptor.getOperands(), 122a54f4eaeSMogball *getTypeConverter(), rewriter); 123a54f4eaeSMogball } 124a54f4eaeSMogball 125a54f4eaeSMogball //===----------------------------------------------------------------------===// 126a54f4eaeSMogball // IndexCastOpLowering 127a54f4eaeSMogball //===----------------------------------------------------------------------===// 128a54f4eaeSMogball 129a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite( 130a54f4eaeSMogball arith::IndexCastOp op, OpAdaptor adaptor, 131a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 132a54f4eaeSMogball auto targetType = typeConverter->convertType(op.getResult().getType()); 133a54f4eaeSMogball auto targetElementType = 134a54f4eaeSMogball typeConverter->convertType(getElementTypeOrSelf(op.getResult())) 135a54f4eaeSMogball .cast<IntegerType>(); 136a54f4eaeSMogball auto sourceElementType = 137*cfb72fd3SJacques Pienaar getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>(); 138a54f4eaeSMogball unsigned targetBits = targetElementType.getWidth(); 139a54f4eaeSMogball unsigned sourceBits = sourceElementType.getWidth(); 140a54f4eaeSMogball 141a54f4eaeSMogball if (targetBits == sourceBits) 142*cfb72fd3SJacques Pienaar rewriter.replaceOp(op, adaptor.getIn()); 143a54f4eaeSMogball else if (targetBits < sourceBits) 144*cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn()); 145a54f4eaeSMogball else 146*cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn()); 147a54f4eaeSMogball return success(); 148a54f4eaeSMogball } 149a54f4eaeSMogball 150a54f4eaeSMogball //===----------------------------------------------------------------------===// 151a54f4eaeSMogball // CmpIOpLowering 152a54f4eaeSMogball //===----------------------------------------------------------------------===// 153a54f4eaeSMogball 154a54f4eaeSMogball // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 155a54f4eaeSMogball // share numerical values so just cast. 156a54f4eaeSMogball template <typename LLVMPredType, typename PredType> 157a54f4eaeSMogball static LLVMPredType convertCmpPredicate(PredType pred) { 158a54f4eaeSMogball return static_cast<LLVMPredType>(pred); 159a54f4eaeSMogball } 160a54f4eaeSMogball 161a54f4eaeSMogball LogicalResult 162a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 163a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 164*cfb72fd3SJacques Pienaar auto operandType = adaptor.getLhs().getType(); 165a54f4eaeSMogball auto resultType = op.getResult().getType(); 166a54f4eaeSMogball 167a54f4eaeSMogball // Handle the scalar and 1D vector cases. 168a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) { 169a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 170a54f4eaeSMogball op, typeConverter->convertType(resultType), 171a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 172*cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 173a54f4eaeSMogball return success(); 174a54f4eaeSMogball } 175a54f4eaeSMogball 176a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>(); 177a54f4eaeSMogball if (!vectorType) 178a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type"); 179a54f4eaeSMogball 180a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors( 181a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 182a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) { 183a54f4eaeSMogball OpAdaptor adaptor(operands); 184a54f4eaeSMogball return rewriter.create<LLVM::ICmpOp>( 185a54f4eaeSMogball op.getLoc(), llvm1DVectorTy, 186a54f4eaeSMogball convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 187*cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 188a54f4eaeSMogball }, 189a54f4eaeSMogball rewriter); 190a54f4eaeSMogball 191a54f4eaeSMogball return success(); 192a54f4eaeSMogball } 193a54f4eaeSMogball 194a54f4eaeSMogball //===----------------------------------------------------------------------===// 195a54f4eaeSMogball // CmpFOpLowering 196a54f4eaeSMogball //===----------------------------------------------------------------------===// 197a54f4eaeSMogball 198a54f4eaeSMogball LogicalResult 199a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 200a54f4eaeSMogball ConversionPatternRewriter &rewriter) const { 201*cfb72fd3SJacques Pienaar auto operandType = adaptor.getLhs().getType(); 202a54f4eaeSMogball auto resultType = op.getResult().getType(); 203a54f4eaeSMogball 204a54f4eaeSMogball // Handle the scalar and 1D vector cases. 205a54f4eaeSMogball if (!operandType.isa<LLVM::LLVMArrayType>()) { 206a54f4eaeSMogball rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 207a54f4eaeSMogball op, typeConverter->convertType(resultType), 208a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 209*cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 210a54f4eaeSMogball return success(); 211a54f4eaeSMogball } 212a54f4eaeSMogball 213a54f4eaeSMogball auto vectorType = resultType.dyn_cast<VectorType>(); 214a54f4eaeSMogball if (!vectorType) 215a54f4eaeSMogball return rewriter.notifyMatchFailure(op, "expected vector result type"); 216a54f4eaeSMogball 217a54f4eaeSMogball return LLVM::detail::handleMultidimensionalVectors( 218a54f4eaeSMogball op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 219a54f4eaeSMogball [&](Type llvm1DVectorTy, ValueRange operands) { 220a54f4eaeSMogball OpAdaptor adaptor(operands); 221a54f4eaeSMogball return rewriter.create<LLVM::FCmpOp>( 222a54f4eaeSMogball op.getLoc(), llvm1DVectorTy, 223a54f4eaeSMogball convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 224*cfb72fd3SJacques Pienaar adaptor.getLhs(), adaptor.getRhs()); 225a54f4eaeSMogball }, 226a54f4eaeSMogball rewriter); 227a54f4eaeSMogball } 228a54f4eaeSMogball 229a54f4eaeSMogball //===----------------------------------------------------------------------===// 230a54f4eaeSMogball // Pass Definition 231a54f4eaeSMogball //===----------------------------------------------------------------------===// 232a54f4eaeSMogball 233a54f4eaeSMogball namespace { 234a54f4eaeSMogball struct ConvertArithmeticToLLVMPass 235a54f4eaeSMogball : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { 236a54f4eaeSMogball ConvertArithmeticToLLVMPass() = default; 237a54f4eaeSMogball 238a54f4eaeSMogball void runOnFunction() override { 239a54f4eaeSMogball LLVMConversionTarget target(getContext()); 240a54f4eaeSMogball RewritePatternSet patterns(&getContext()); 241a54f4eaeSMogball 242a54f4eaeSMogball LowerToLLVMOptions options(&getContext()); 243a54f4eaeSMogball if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 244a54f4eaeSMogball options.overrideIndexBitwidth(indexBitwidth); 245a54f4eaeSMogball 246a54f4eaeSMogball LLVMTypeConverter converter(&getContext(), options); 247a54f4eaeSMogball mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 248a54f4eaeSMogball patterns); 249a54f4eaeSMogball 250a54f4eaeSMogball if (failed( 251a54f4eaeSMogball applyPartialConversion(getFunction(), target, std::move(patterns)))) 252a54f4eaeSMogball signalPassFailure(); 253a54f4eaeSMogball } 254a54f4eaeSMogball }; 255a54f4eaeSMogball } // end anonymous namespace 256a54f4eaeSMogball 257a54f4eaeSMogball //===----------------------------------------------------------------------===// 258a54f4eaeSMogball // Pattern Population 259a54f4eaeSMogball //===----------------------------------------------------------------------===// 260a54f4eaeSMogball 261a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns( 262a54f4eaeSMogball LLVMTypeConverter &converter, RewritePatternSet &patterns) { 263a54f4eaeSMogball // clang-format off 264a54f4eaeSMogball patterns.add< 265a54f4eaeSMogball ConstantOpLowering, 266a54f4eaeSMogball AddIOpLowering, 267a54f4eaeSMogball SubIOpLowering, 268a54f4eaeSMogball MulIOpLowering, 269a54f4eaeSMogball DivUIOpLowering, 270a54f4eaeSMogball DivSIOpLowering, 271a54f4eaeSMogball RemUIOpLowering, 272a54f4eaeSMogball RemSIOpLowering, 273a54f4eaeSMogball AndIOpLowering, 274a54f4eaeSMogball OrIOpLowering, 275a54f4eaeSMogball XOrIOpLowering, 276a54f4eaeSMogball ShLIOpLowering, 277a54f4eaeSMogball ShRUIOpLowering, 278a54f4eaeSMogball ShRSIOpLowering, 279a54f4eaeSMogball NegFOpLowering, 280a54f4eaeSMogball AddFOpLowering, 281a54f4eaeSMogball SubFOpLowering, 282a54f4eaeSMogball MulFOpLowering, 283a54f4eaeSMogball DivFOpLowering, 284a54f4eaeSMogball RemFOpLowering, 285a54f4eaeSMogball ExtUIOpLowering, 286a54f4eaeSMogball ExtSIOpLowering, 287a54f4eaeSMogball ExtFOpLowering, 288a54f4eaeSMogball TruncIOpLowering, 289a54f4eaeSMogball TruncFOpLowering, 290a54f4eaeSMogball UIToFPOpLowering, 291a54f4eaeSMogball SIToFPOpLowering, 292a54f4eaeSMogball FPToUIOpLowering, 293a54f4eaeSMogball FPToSIOpLowering, 294a54f4eaeSMogball IndexCastOpLowering, 295a54f4eaeSMogball BitcastOpLowering, 296a54f4eaeSMogball CmpIOpLowering, 297a54f4eaeSMogball CmpFOpLowering 298a54f4eaeSMogball >(converter); 299a54f4eaeSMogball // clang-format on 300a54f4eaeSMogball } 301a54f4eaeSMogball 302a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() { 303a54f4eaeSMogball return std::make_unique<ConvertArithmeticToLLVMPass>(); 304a54f4eaeSMogball } 305