1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/TypeUtilities.h"
16
17 using namespace mlir;
18
19 namespace {
20
21 //===----------------------------------------------------------------------===//
22 // Straightforward Op Lowerings
23 //===----------------------------------------------------------------------===//
24
25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28 using DivUIOpLowering =
29 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30 using DivSIOpLowering =
31 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32 using RemUIOpLowering =
33 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34 using RemSIOpLowering =
35 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40 using ShRUIOpLowering =
41 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42 using ShRSIOpLowering =
43 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50 using ExtUIOpLowering =
51 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52 using ExtSIOpLowering =
53 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55 using TruncIOpLowering =
56 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57 using TruncFOpLowering =
58 VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59 using UIToFPOpLowering =
60 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61 using SIToFPOpLowering =
62 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63 using FPToUIOpLowering =
64 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65 using FPToSIOpLowering =
66 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67 using BitcastOpLowering =
68 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69 using SelectOpLowering =
70 VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
71
72 //===----------------------------------------------------------------------===//
73 // Op Lowering Patterns
74 //===----------------------------------------------------------------------===//
75
76 /// Directly lower to LLVM op.
77 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
78 using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
79
80 LogicalResult
81 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
82 ConversionPatternRewriter &rewriter) const override;
83 };
84
85 /// The lowering of index_cast becomes an integer conversion since index
86 /// becomes an integer. If the bit width of the source and target integer
87 /// types is the same, just erase the cast. If the target type is wider,
88 /// sign-extend the value, otherwise truncate it.
89 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
90 using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
91
92 LogicalResult
93 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
94 ConversionPatternRewriter &rewriter) const override;
95 };
96
97 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
98 using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
99
100 LogicalResult
101 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override;
103 };
104
105 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
106 using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
107
108 LogicalResult
109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override;
111 };
112
113 } // namespace
114
115 //===----------------------------------------------------------------------===//
116 // ConstantOpLowering
117 //===----------------------------------------------------------------------===//
118
119 LogicalResult
matchAndRewrite(arith::ConstantOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
121 ConversionPatternRewriter &rewriter) const {
122 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
123 adaptor.getOperands(),
124 *getTypeConverter(), rewriter);
125 }
126
127 //===----------------------------------------------------------------------===//
128 // IndexCastOpLowering
129 //===----------------------------------------------------------------------===//
130
matchAndRewrite(arith::IndexCastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const131 LogicalResult IndexCastOpLowering::matchAndRewrite(
132 arith::IndexCastOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const {
134 auto targetType = typeConverter->convertType(op.getResult().getType());
135 auto targetElementType =
136 typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
137 .cast<IntegerType>();
138 auto sourceElementType =
139 getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
140 unsigned targetBits = targetElementType.getWidth();
141 unsigned sourceBits = sourceElementType.getWidth();
142
143 if (targetBits == sourceBits)
144 rewriter.replaceOp(op, adaptor.getIn());
145 else if (targetBits < sourceBits)
146 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
147 else
148 rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
149 return success();
150 }
151
152 //===----------------------------------------------------------------------===//
153 // CmpIOpLowering
154 //===----------------------------------------------------------------------===//
155
// Translate an arith comparison predicate into its LLVM dialect counterpart.
// The source and destination enums use the same numerical values, so a plain
// cast is all that is needed.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto llvmPredicate = static_cast<LLVMPredType>(pred);
  return llvmPredicate;
}
162
163 LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const164 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter) const {
166 auto operandType = adaptor.getLhs().getType();
167 auto resultType = op.getResult().getType();
168
169 // Handle the scalar and 1D vector cases.
170 if (!operandType.isa<LLVM::LLVMArrayType>()) {
171 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
172 op, typeConverter->convertType(resultType),
173 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
174 adaptor.getLhs(), adaptor.getRhs());
175 return success();
176 }
177
178 auto vectorType = resultType.dyn_cast<VectorType>();
179 if (!vectorType)
180 return rewriter.notifyMatchFailure(op, "expected vector result type");
181
182 return LLVM::detail::handleMultidimensionalVectors(
183 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
184 [&](Type llvm1DVectorTy, ValueRange operands) {
185 OpAdaptor adaptor(operands);
186 return rewriter.create<LLVM::ICmpOp>(
187 op.getLoc(), llvm1DVectorTy,
188 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
189 adaptor.getLhs(), adaptor.getRhs());
190 },
191 rewriter);
192 }
193
194 //===----------------------------------------------------------------------===//
195 // CmpFOpLowering
196 //===----------------------------------------------------------------------===//
197
198 LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter) const {
201 auto operandType = adaptor.getLhs().getType();
202 auto resultType = op.getResult().getType();
203
204 // Handle the scalar and 1D vector cases.
205 if (!operandType.isa<LLVM::LLVMArrayType>()) {
206 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207 op, typeConverter->convertType(resultType),
208 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209 adaptor.getLhs(), adaptor.getRhs());
210 return success();
211 }
212
213 auto vectorType = resultType.dyn_cast<VectorType>();
214 if (!vectorType)
215 return rewriter.notifyMatchFailure(op, "expected vector result type");
216
217 return LLVM::detail::handleMultidimensionalVectors(
218 op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219 [&](Type llvm1DVectorTy, ValueRange operands) {
220 OpAdaptor adaptor(operands);
221 return rewriter.create<LLVM::FCmpOp>(
222 op.getLoc(), llvm1DVectorTy,
223 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224 adaptor.getLhs(), adaptor.getRhs());
225 },
226 rewriter);
227 }
228
229 //===----------------------------------------------------------------------===//
230 // Pass Definition
231 //===----------------------------------------------------------------------===//
232
233 namespace {
234 struct ConvertArithmeticToLLVMPass
235 : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236 ConvertArithmeticToLLVMPass() = default;
237
runOnOperation__anonb4369bed0411::ConvertArithmeticToLLVMPass238 void runOnOperation() override {
239 LLVMConversionTarget target(getContext());
240 RewritePatternSet patterns(&getContext());
241
242 LowerToLLVMOptions options(&getContext());
243 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244 options.overrideIndexBitwidth(indexBitwidth);
245
246 LLVMTypeConverter converter(&getContext(), options);
247 mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248 patterns);
249
250 if (failed(applyPartialConversion(getOperation(), target,
251 std::move(patterns))))
252 signalPassFailure();
253 }
254 };
255 } // namespace
256
257 //===----------------------------------------------------------------------===//
258 // Pattern Population
259 //===----------------------------------------------------------------------===//
260
populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)261 void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263 // clang-format off
264 patterns.add<
265 ConstantOpLowering,
266 AddIOpLowering,
267 SubIOpLowering,
268 MulIOpLowering,
269 DivUIOpLowering,
270 DivSIOpLowering,
271 RemUIOpLowering,
272 RemSIOpLowering,
273 AndIOpLowering,
274 OrIOpLowering,
275 XOrIOpLowering,
276 ShLIOpLowering,
277 ShRUIOpLowering,
278 ShRSIOpLowering,
279 NegFOpLowering,
280 AddFOpLowering,
281 SubFOpLowering,
282 MulFOpLowering,
283 DivFOpLowering,
284 RemFOpLowering,
285 ExtUIOpLowering,
286 ExtSIOpLowering,
287 ExtFOpLowering,
288 TruncIOpLowering,
289 TruncFOpLowering,
290 UIToFPOpLowering,
291 SIToFPOpLowering,
292 FPToUIOpLowering,
293 FPToSIOpLowering,
294 IndexCastOpLowering,
295 BitcastOpLowering,
296 CmpIOpLowering,
297 CmpFOpLowering,
298 SelectOpLowering
299 >(converter);
300 // clang-format on
301 }
302
createConvertArithmeticToLLVMPass()303 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
304 return std::make_unique<ConvertArithmeticToLLVMPass>();
305 }
306