1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Straightforward Op Lowerings 23 //===----------------------------------------------------------------------===// 24 25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>; 26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>; 27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>; 28 using DivUIOpLowering = 29 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 30 using DivSIOpLowering = 31 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 32 using RemUIOpLowering = 33 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 34 using RemSIOpLowering = 35 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>; 40 using ShRUIOpLowering = 41 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 42 using ShRSIOpLowering = 43 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>; 45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>; 46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>; 47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>; 48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>; 49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>; 50 using ExtUIOpLowering = 51 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 52 using ExtSIOpLowering = 53 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 55 using TruncIOpLowering = 56 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 57 using TruncFOpLowering = 58 VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>; 59 using UIToFPOpLowering = 60 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 61 using SIToFPOpLowering = 62 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 63 using FPToUIOpLowering = 64 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 65 using FPToSIOpLowering = 66 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 67 using BitcastOpLowering = 68 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 69 using SelectOpLowering = 70 VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>; 71 72 //===----------------------------------------------------------------------===// 73 // Op Lowering Patterns 74 //===----------------------------------------------------------------------===// 75 76 /// Directly lower to LLVM op. 77 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 78 using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern; 79 80 LogicalResult 81 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 82 ConversionPatternRewriter &rewriter) const override; 83 }; 84 85 /// The lowering of index_cast becomes an integer conversion since index 86 /// becomes an integer. If the bit width of the source and target integer 87 /// types is the same, just erase the cast. If the target type is wider, 88 /// sign-extend the value, otherwise truncate it. 89 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> { 90 using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern; 91 92 LogicalResult 93 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 94 ConversionPatternRewriter &rewriter) const override; 95 }; 96 97 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 98 using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern; 99 100 LogicalResult 101 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 102 ConversionPatternRewriter &rewriter) const override; 103 }; 104 105 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 106 using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern; 107 108 LogicalResult 109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 110 ConversionPatternRewriter &rewriter) const override; 111 }; 112 113 } // namespace 114 115 //===----------------------------------------------------------------------===// 116 // ConstantOpLowering 117 //===----------------------------------------------------------------------===// 118 119 LogicalResult 120 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 121 ConversionPatternRewriter &rewriter) const { 122 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 123 adaptor.getOperands(), 124 *getTypeConverter(), rewriter); 125 } 126 127 //===----------------------------------------------------------------------===// 128 // IndexCastOpLowering 129 //===----------------------------------------------------------------------===// 130 131 LogicalResult IndexCastOpLowering::matchAndRewrite( 132 arith::IndexCastOp op, OpAdaptor adaptor, 133 ConversionPatternRewriter &rewriter) const { 134 auto targetType = typeConverter->convertType(op.getResult().getType()); 135 auto targetElementType = 136 typeConverter->convertType(getElementTypeOrSelf(op.getResult())) 137 .cast<IntegerType>(); 138 auto sourceElementType = 139 getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>(); 140 unsigned targetBits = targetElementType.getWidth(); 141 unsigned sourceBits = sourceElementType.getWidth(); 142 143 if (targetBits == sourceBits) 144 rewriter.replaceOp(op, adaptor.getIn()); 145 else if (targetBits < sourceBits) 146 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn()); 147 else 148 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn()); 149 return success(); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // CmpIOpLowering 154 //===----------------------------------------------------------------------===// 155 156 // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 157 // share numerical values so just cast. 158 template <typename LLVMPredType, typename PredType> 159 static LLVMPredType convertCmpPredicate(PredType pred) { 160 return static_cast<LLVMPredType>(pred); 161 } 162 163 LogicalResult 164 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 165 ConversionPatternRewriter &rewriter) const { 166 auto operandType = adaptor.getLhs().getType(); 167 auto resultType = op.getResult().getType(); 168 169 // Handle the scalar and 1D vector cases. 170 if (!operandType.isa<LLVM::LLVMArrayType>()) { 171 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 172 op, typeConverter->convertType(resultType), 173 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 174 adaptor.getLhs(), adaptor.getRhs()); 175 return success(); 176 } 177 178 auto vectorType = resultType.dyn_cast<VectorType>(); 179 if (!vectorType) 180 return rewriter.notifyMatchFailure(op, "expected vector result type"); 181 182 return LLVM::detail::handleMultidimensionalVectors( 183 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 184 [&](Type llvm1DVectorTy, ValueRange operands) { 185 OpAdaptor adaptor(operands); 186 return rewriter.create<LLVM::ICmpOp>( 187 op.getLoc(), llvm1DVectorTy, 188 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 189 adaptor.getLhs(), adaptor.getRhs()); 190 }, 191 rewriter); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // CmpFOpLowering 196 //===----------------------------------------------------------------------===// 197 198 LogicalResult 199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const { 201 auto operandType = adaptor.getLhs().getType(); 202 auto resultType = op.getResult().getType(); 203 204 // Handle the scalar and 1D vector cases. 205 if (!operandType.isa<LLVM::LLVMArrayType>()) { 206 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 207 op, typeConverter->convertType(resultType), 208 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 209 adaptor.getLhs(), adaptor.getRhs()); 210 return success(); 211 } 212 213 auto vectorType = resultType.dyn_cast<VectorType>(); 214 if (!vectorType) 215 return rewriter.notifyMatchFailure(op, "expected vector result type"); 216 217 return LLVM::detail::handleMultidimensionalVectors( 218 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 219 [&](Type llvm1DVectorTy, ValueRange operands) { 220 OpAdaptor adaptor(operands); 221 return rewriter.create<LLVM::FCmpOp>( 222 op.getLoc(), llvm1DVectorTy, 223 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 224 adaptor.getLhs(), adaptor.getRhs()); 225 }, 226 rewriter); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // Pass Definition 231 //===----------------------------------------------------------------------===// 232 233 namespace { 234 struct ConvertArithmeticToLLVMPass 235 : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { 236 ConvertArithmeticToLLVMPass() = default; 237 238 void runOnOperation() override { 239 LLVMConversionTarget target(getContext()); 240 RewritePatternSet patterns(&getContext()); 241 242 LowerToLLVMOptions options(&getContext()); 243 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 244 options.overrideIndexBitwidth(indexBitwidth); 245 246 LLVMTypeConverter converter(&getContext(), options); 247 mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 248 patterns); 249 250 if (failed(applyPartialConversion(getOperation(), target, 251 std::move(patterns)))) 252 signalPassFailure(); 253 } 254 }; 255 } // namespace 256 257 //===----------------------------------------------------------------------===// 258 // Pattern Population 259 //===----------------------------------------------------------------------===// 260 261 void mlir::arith::populateArithmeticToLLVMConversionPatterns( 262 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 263 // clang-format off 264 patterns.add< 265 ConstantOpLowering, 266 AddIOpLowering, 267 SubIOpLowering, 268 MulIOpLowering, 269 DivUIOpLowering, 270 DivSIOpLowering, 271 RemUIOpLowering, 272 RemSIOpLowering, 273 AndIOpLowering, 274 OrIOpLowering, 275 XOrIOpLowering, 276 ShLIOpLowering, 277 ShRUIOpLowering, 278 ShRSIOpLowering, 279 NegFOpLowering, 280 AddFOpLowering, 281 SubFOpLowering, 282 MulFOpLowering, 283 DivFOpLowering, 284 RemFOpLowering, 285 ExtUIOpLowering, 286 ExtSIOpLowering, 287 ExtFOpLowering, 288 TruncIOpLowering, 289 TruncFOpLowering, 290 UIToFPOpLowering, 291 SIToFPOpLowering, 292 FPToUIOpLowering, 293 FPToSIOpLowering, 294 IndexCastOpLowering, 295 BitcastOpLowering, 296 CmpIOpLowering, 297 CmpFOpLowering, 298 SelectOpLowering 299 >(converter); 300 // clang-format on 301 } 302 303 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() { 304 return std::make_unique<ConvertArithmeticToLLVMPass>(); 305 } 306