1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Straightforward Op Lowerings 23 //===----------------------------------------------------------------------===// 24 25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>; 26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>; 27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>; 28 using DivUIOpLowering = 29 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 30 using DivSIOpLowering = 31 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 32 using RemUIOpLowering = 33 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 34 using RemSIOpLowering = 35 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>; 40 using ShRUIOpLowering = 41 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 42 using ShRSIOpLowering = 43 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>; 45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>; 46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>; 47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>; 48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>; 49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>; 50 using ExtUIOpLowering = 51 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 52 using ExtSIOpLowering = 53 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 55 using TruncIOpLowering = 56 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 57 using TruncFOpLowering = 58 VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>; 59 using UIToFPOpLowering = 60 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 61 using SIToFPOpLowering = 62 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 63 using FPToUIOpLowering = 64 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 65 using FPToSIOpLowering = 66 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 67 using BitcastOpLowering = 68 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 69 70 //===----------------------------------------------------------------------===// 71 // Op Lowering Patterns 72 //===----------------------------------------------------------------------===// 73 74 /// Directly lower to LLVM op. 75 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 76 using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern; 77 78 LogicalResult 79 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override; 81 }; 82 83 /// The lowering of index_cast becomes an integer conversion since index 84 /// becomes an integer. If the bit width of the source and target integer 85 /// types is the same, just erase the cast. If the target type is wider, 86 /// sign-extend the value, otherwise truncate it. 87 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> { 88 using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern; 89 90 LogicalResult 91 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 92 ConversionPatternRewriter &rewriter) const override; 93 }; 94 95 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 96 using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern; 97 98 LogicalResult 99 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 100 ConversionPatternRewriter &rewriter) const override; 101 }; 102 103 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 104 using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern; 105 106 LogicalResult 107 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 } // namespace 112 113 //===----------------------------------------------------------------------===// 114 // ConstantOpLowering 115 //===----------------------------------------------------------------------===// 116 117 LogicalResult 118 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const { 120 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 121 adaptor.getOperands(), 122 *getTypeConverter(), rewriter); 123 } 124 125 //===----------------------------------------------------------------------===// 126 // IndexCastOpLowering 127 //===----------------------------------------------------------------------===// 128 129 LogicalResult IndexCastOpLowering::matchAndRewrite( 130 arith::IndexCastOp op, OpAdaptor adaptor, 131 ConversionPatternRewriter &rewriter) const { 132 auto targetType = typeConverter->convertType(op.getResult().getType()); 133 auto targetElementType = 134 typeConverter->convertType(getElementTypeOrSelf(op.getResult())) 135 .cast<IntegerType>(); 136 auto sourceElementType = 137 getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>(); 138 unsigned targetBits = targetElementType.getWidth(); 139 unsigned sourceBits = sourceElementType.getWidth(); 140 141 if (targetBits == sourceBits) 142 rewriter.replaceOp(op, adaptor.getIn()); 143 else if (targetBits < sourceBits) 144 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn()); 145 else 146 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn()); 147 return success(); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // CmpIOpLowering 152 //===----------------------------------------------------------------------===// 153 154 // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 155 // share numerical values so just cast. 156 template <typename LLVMPredType, typename PredType> 157 static LLVMPredType convertCmpPredicate(PredType pred) { 158 return static_cast<LLVMPredType>(pred); 159 } 160 161 LogicalResult 162 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const { 164 auto operandType = adaptor.getLhs().getType(); 165 auto resultType = op.getResult().getType(); 166 167 // Handle the scalar and 1D vector cases. 168 if (!operandType.isa<LLVM::LLVMArrayType>()) { 169 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 170 op, typeConverter->convertType(resultType), 171 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 172 adaptor.getLhs(), adaptor.getRhs()); 173 return success(); 174 } 175 176 auto vectorType = resultType.dyn_cast<VectorType>(); 177 if (!vectorType) 178 return rewriter.notifyMatchFailure(op, "expected vector result type"); 179 180 return LLVM::detail::handleMultidimensionalVectors( 181 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 182 [&](Type llvm1DVectorTy, ValueRange operands) { 183 OpAdaptor adaptor(operands); 184 return rewriter.create<LLVM::ICmpOp>( 185 op.getLoc(), llvm1DVectorTy, 186 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 187 adaptor.getLhs(), adaptor.getRhs()); 188 }, 189 rewriter); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // CmpFOpLowering 194 //===----------------------------------------------------------------------===// 195 196 LogicalResult 197 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 198 ConversionPatternRewriter &rewriter) const { 199 auto operandType = adaptor.getLhs().getType(); 200 auto resultType = op.getResult().getType(); 201 202 // Handle the scalar and 1D vector cases. 203 if (!operandType.isa<LLVM::LLVMArrayType>()) { 204 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 205 op, typeConverter->convertType(resultType), 206 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 207 adaptor.getLhs(), adaptor.getRhs()); 208 return success(); 209 } 210 211 auto vectorType = resultType.dyn_cast<VectorType>(); 212 if (!vectorType) 213 return rewriter.notifyMatchFailure(op, "expected vector result type"); 214 215 return LLVM::detail::handleMultidimensionalVectors( 216 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 217 [&](Type llvm1DVectorTy, ValueRange operands) { 218 OpAdaptor adaptor(operands); 219 return rewriter.create<LLVM::FCmpOp>( 220 op.getLoc(), llvm1DVectorTy, 221 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 222 adaptor.getLhs(), adaptor.getRhs()); 223 }, 224 rewriter); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // Pass Definition 229 //===----------------------------------------------------------------------===// 230 231 namespace { 232 struct ConvertArithmeticToLLVMPass 233 : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { 234 ConvertArithmeticToLLVMPass() = default; 235 236 void runOnOperation() override { 237 LLVMConversionTarget target(getContext()); 238 RewritePatternSet patterns(&getContext()); 239 240 LowerToLLVMOptions options(&getContext()); 241 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 242 options.overrideIndexBitwidth(indexBitwidth); 243 244 LLVMTypeConverter converter(&getContext(), options); 245 mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 246 patterns); 247 248 if (failed(applyPartialConversion(getOperation(), target, 249 std::move(patterns)))) 250 signalPassFailure(); 251 } 252 }; 253 } // namespace 254 255 //===----------------------------------------------------------------------===// 256 // Pattern Population 257 //===----------------------------------------------------------------------===// 258 259 void mlir::arith::populateArithmeticToLLVMConversionPatterns( 260 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 261 // clang-format off 262 patterns.add< 263 ConstantOpLowering, 264 AddIOpLowering, 265 SubIOpLowering, 266 MulIOpLowering, 267 DivUIOpLowering, 268 DivSIOpLowering, 269 RemUIOpLowering, 270 RemSIOpLowering, 271 AndIOpLowering, 272 OrIOpLowering, 273 XOrIOpLowering, 274 ShLIOpLowering, 275 ShRUIOpLowering, 276 ShRSIOpLowering, 277 NegFOpLowering, 278 AddFOpLowering, 279 SubFOpLowering, 280 MulFOpLowering, 281 DivFOpLowering, 282 RemFOpLowering, 283 ExtUIOpLowering, 284 ExtSIOpLowering, 285 ExtFOpLowering, 286 TruncIOpLowering, 287 TruncFOpLowering, 288 UIToFPOpLowering, 289 SIToFPOpLowering, 290 FPToUIOpLowering, 291 FPToSIOpLowering, 292 IndexCastOpLowering, 293 BitcastOpLowering, 294 CmpIOpLowering, 295 CmpFOpLowering 296 >(converter); 297 // clang-format on 298 } 299 300 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() { 301 return std::make_unique<ConvertArithmeticToLLVMPass>(); 302 } 303