1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "../PassDetail.h" 11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 15 #include "mlir/IR/TypeUtilities.h" 16 17 using namespace mlir; 18 19 namespace { 20 21 //===----------------------------------------------------------------------===// 22 // Straightforward Op Lowerings 23 //===----------------------------------------------------------------------===// 24 25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>; 26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>; 27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>; 28 using DivUIOpLowering = 29 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 30 using DivSIOpLowering = 31 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 32 using RemUIOpLowering = 33 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 34 using RemSIOpLowering = 35 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>; 40 using ShRUIOpLowering = 41 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 42 using ShRSIOpLowering = 43 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>; 45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>; 46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>; 47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>; 48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>; 49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>; 50 using ExtUIOpLowering = 51 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 52 using ExtSIOpLowering = 53 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 55 using TruncIOpLowering = 56 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 57 using TruncFOpLowering = 58 VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>; 59 using UIToFPOpLowering = 60 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 61 using SIToFPOpLowering = 62 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 63 using FPToUIOpLowering = 64 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 65 using FPToSIOpLowering = 66 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 67 using BitcastOpLowering = 68 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 69 70 //===----------------------------------------------------------------------===// 71 // Op Lowering Patterns 72 //===----------------------------------------------------------------------===// 73 74 /// Directly lower to LLVM op. 75 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 76 using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern; 77 78 LogicalResult 79 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 80 ConversionPatternRewriter &rewriter) const override; 81 }; 82 83 /// The lowering of index_cast becomes an integer conversion since index 84 /// becomes an integer. If the bit width of the source and target integer 85 /// types is the same, just erase the cast. If the target type is wider, 86 /// sign-extend the value, otherwise truncate it. 87 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> { 88 using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern; 89 90 LogicalResult 91 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, 92 ConversionPatternRewriter &rewriter) const override; 93 }; 94 95 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 96 using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern; 97 98 LogicalResult 99 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 100 ConversionPatternRewriter &rewriter) const override; 101 }; 102 103 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 104 using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern; 105 106 LogicalResult 107 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 108 ConversionPatternRewriter &rewriter) const override; 109 }; 110 111 } // namespace 112 113 //===----------------------------------------------------------------------===// 114 // ConstantOpLowering 115 //===----------------------------------------------------------------------===// 116 117 LogicalResult 118 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 119 ConversionPatternRewriter &rewriter) const { 120 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 121 adaptor.getOperands(), 122 *getTypeConverter(), rewriter); 123 } 124 125 //===----------------------------------------------------------------------===// 126 // IndexCastOpLowering 127 //===----------------------------------------------------------------------===// 128 129 LogicalResult IndexCastOpLowering::matchAndRewrite( 130 arith::IndexCastOp op, OpAdaptor adaptor, 131 ConversionPatternRewriter &rewriter) const { 132 auto targetType = typeConverter->convertType(op.getResult().getType()); 133 auto targetElementType = 134 typeConverter->convertType(getElementTypeOrSelf(op.getResult())) 135 .cast<IntegerType>(); 136 auto sourceElementType = 137 getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>(); 138 unsigned targetBits = targetElementType.getWidth(); 139 unsigned sourceBits = sourceElementType.getWidth(); 140 141 if (targetBits == sourceBits) 142 rewriter.replaceOp(op, adaptor.getIn()); 143 else if (targetBits < sourceBits) 144 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn()); 145 else 146 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn()); 147 return success(); 148 } 149 150 //===----------------------------------------------------------------------===// 151 // CmpIOpLowering 152 //===----------------------------------------------------------------------===// 153 154 // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 155 // share numerical values so just cast. 156 template <typename LLVMPredType, typename PredType> 157 static LLVMPredType convertCmpPredicate(PredType pred) { 158 return static_cast<LLVMPredType>(pred); 159 } 160 161 LogicalResult 162 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 163 ConversionPatternRewriter &rewriter) const { 164 auto operandType = adaptor.getLhs().getType(); 165 auto resultType = op.getResult().getType(); 166 167 // Handle the scalar and 1D vector cases. 168 if (!operandType.isa<LLVM::LLVMArrayType>()) { 169 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 170 op, typeConverter->convertType(resultType), 171 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 172 adaptor.getLhs(), adaptor.getRhs()); 173 return success(); 174 } 175 176 auto vectorType = resultType.dyn_cast<VectorType>(); 177 if (!vectorType) 178 return rewriter.notifyMatchFailure(op, "expected vector result type"); 179 180 return LLVM::detail::handleMultidimensionalVectors( 181 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 182 [&](Type llvm1DVectorTy, ValueRange operands) { 183 OpAdaptor adaptor(operands); 184 return rewriter.create<LLVM::ICmpOp>( 185 op.getLoc(), llvm1DVectorTy, 186 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 187 adaptor.getLhs(), adaptor.getRhs()); 188 }, 189 rewriter); 190 191 return success(); 192 } 193 194 //===----------------------------------------------------------------------===// 195 // CmpFOpLowering 196 //===----------------------------------------------------------------------===// 197 198 LogicalResult 199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 200 ConversionPatternRewriter &rewriter) const { 201 auto operandType = adaptor.getLhs().getType(); 202 auto resultType = op.getResult().getType(); 203 204 // Handle the scalar and 1D vector cases. 205 if (!operandType.isa<LLVM::LLVMArrayType>()) { 206 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 207 op, typeConverter->convertType(resultType), 208 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 209 adaptor.getLhs(), adaptor.getRhs()); 210 return success(); 211 } 212 213 auto vectorType = resultType.dyn_cast<VectorType>(); 214 if (!vectorType) 215 return rewriter.notifyMatchFailure(op, "expected vector result type"); 216 217 return LLVM::detail::handleMultidimensionalVectors( 218 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 219 [&](Type llvm1DVectorTy, ValueRange operands) { 220 OpAdaptor adaptor(operands); 221 return rewriter.create<LLVM::FCmpOp>( 222 op.getLoc(), llvm1DVectorTy, 223 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 224 adaptor.getLhs(), adaptor.getRhs()); 225 }, 226 rewriter); 227 } 228 229 //===----------------------------------------------------------------------===// 230 // Pass Definition 231 //===----------------------------------------------------------------------===// 232 233 namespace { 234 struct ConvertArithmeticToLLVMPass 235 : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> { 236 ConvertArithmeticToLLVMPass() = default; 237 238 void runOnFunction() override { 239 LLVMConversionTarget target(getContext()); 240 RewritePatternSet patterns(&getContext()); 241 242 LowerToLLVMOptions options(&getContext()); 243 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 244 options.overrideIndexBitwidth(indexBitwidth); 245 246 LLVMTypeConverter converter(&getContext(), options); 247 mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, 248 patterns); 249 250 if (failed( 251 applyPartialConversion(getFunction(), target, std::move(patterns)))) 252 signalPassFailure(); 253 } 254 }; 255 } // namespace 256 257 //===----------------------------------------------------------------------===// 258 // Pattern Population 259 //===----------------------------------------------------------------------===// 260 261 void mlir::arith::populateArithmeticToLLVMConversionPatterns( 262 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 263 // clang-format off 264 patterns.add< 265 ConstantOpLowering, 266 AddIOpLowering, 267 SubIOpLowering, 268 MulIOpLowering, 269 DivUIOpLowering, 270 DivSIOpLowering, 271 RemUIOpLowering, 272 RemSIOpLowering, 273 AndIOpLowering, 274 OrIOpLowering, 275 XOrIOpLowering, 276 ShLIOpLowering, 277 ShRUIOpLowering, 278 ShRSIOpLowering, 279 NegFOpLowering, 280 AddFOpLowering, 281 SubFOpLowering, 282 MulFOpLowering, 283 DivFOpLowering, 284 RemFOpLowering, 285 ExtUIOpLowering, 286 ExtSIOpLowering, 287 ExtFOpLowering, 288 TruncIOpLowering, 289 TruncFOpLowering, 290 UIToFPOpLowering, 291 SIToFPOpLowering, 292 FPToUIOpLowering, 293 FPToSIOpLowering, 294 IndexCastOpLowering, 295 BitcastOpLowering, 296 CmpIOpLowering, 297 CmpFOpLowering 298 >(converter); 299 // clang-format on 300 } 301 302 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() { 303 return std::make_unique<ConvertArithmeticToLLVMPass>(); 304 } 305