1a54f4eaeSMogball //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball 
9a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10a54f4eaeSMogball #include "../PassDetail.h"
11a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12a54f4eaeSMogball #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14a54f4eaeSMogball #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15a54f4eaeSMogball #include "mlir/IR/TypeUtilities.h"
16a54f4eaeSMogball 
17a54f4eaeSMogball using namespace mlir;
18a54f4eaeSMogball 
19a54f4eaeSMogball namespace {
20a54f4eaeSMogball 
21a54f4eaeSMogball //===----------------------------------------------------------------------===//
22a54f4eaeSMogball // Straightforward Op Lowerings
23a54f4eaeSMogball //===----------------------------------------------------------------------===//
24a54f4eaeSMogball 
25a54f4eaeSMogball using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26a54f4eaeSMogball using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27a54f4eaeSMogball using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28a54f4eaeSMogball using DivUIOpLowering =
29a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30a54f4eaeSMogball using DivSIOpLowering =
31a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32a54f4eaeSMogball using RemUIOpLowering =
33a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34a54f4eaeSMogball using RemSIOpLowering =
35a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36a54f4eaeSMogball using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37a54f4eaeSMogball using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38a54f4eaeSMogball using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39a54f4eaeSMogball using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40a54f4eaeSMogball using ShRUIOpLowering =
41a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42a54f4eaeSMogball using ShRSIOpLowering =
43a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44a54f4eaeSMogball using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45a54f4eaeSMogball using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46a54f4eaeSMogball using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47a54f4eaeSMogball using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48a54f4eaeSMogball using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49a54f4eaeSMogball using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50a54f4eaeSMogball using ExtUIOpLowering =
51a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52a54f4eaeSMogball using ExtSIOpLowering =
53a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54a54f4eaeSMogball using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55a54f4eaeSMogball using TruncIOpLowering =
56a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57a54f4eaeSMogball using TruncFOpLowering =
58a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59a54f4eaeSMogball using UIToFPOpLowering =
60a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61a54f4eaeSMogball using SIToFPOpLowering =
62a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63a54f4eaeSMogball using FPToUIOpLowering =
64a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65a54f4eaeSMogball using FPToSIOpLowering =
66a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67a54f4eaeSMogball using BitcastOpLowering =
68a54f4eaeSMogball     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69a54f4eaeSMogball 
70a54f4eaeSMogball //===----------------------------------------------------------------------===//
71a54f4eaeSMogball // Op Lowering Patterns
72a54f4eaeSMogball //===----------------------------------------------------------------------===//
73a54f4eaeSMogball 
74a54f4eaeSMogball /// Directly lower to LLVM op.
75a54f4eaeSMogball struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
76a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
77a54f4eaeSMogball 
78a54f4eaeSMogball   LogicalResult
79a54f4eaeSMogball   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
80a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
81a54f4eaeSMogball };
82a54f4eaeSMogball 
83a54f4eaeSMogball /// The lowering of index_cast becomes an integer conversion since index
84a54f4eaeSMogball /// becomes an integer.  If the bit width of the source and target integer
85a54f4eaeSMogball /// types is the same, just erase the cast.  If the target type is wider,
86a54f4eaeSMogball /// sign-extend the value, otherwise truncate it.
87a54f4eaeSMogball struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
88a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
89a54f4eaeSMogball 
90a54f4eaeSMogball   LogicalResult
91a54f4eaeSMogball   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
92a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
93a54f4eaeSMogball };
94a54f4eaeSMogball 
95a54f4eaeSMogball struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
96a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
97a54f4eaeSMogball 
98a54f4eaeSMogball   LogicalResult
99a54f4eaeSMogball   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
100a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
101a54f4eaeSMogball };
102a54f4eaeSMogball 
103a54f4eaeSMogball struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
104a54f4eaeSMogball   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
105a54f4eaeSMogball 
106a54f4eaeSMogball   LogicalResult
107a54f4eaeSMogball   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
108a54f4eaeSMogball                   ConversionPatternRewriter &rewriter) const override;
109a54f4eaeSMogball };
110a54f4eaeSMogball 
111a54f4eaeSMogball } // end anonymous namespace
112a54f4eaeSMogball 
113a54f4eaeSMogball //===----------------------------------------------------------------------===//
114a54f4eaeSMogball // ConstantOpLowering
115a54f4eaeSMogball //===----------------------------------------------------------------------===//
116a54f4eaeSMogball 
117a54f4eaeSMogball LogicalResult
118a54f4eaeSMogball ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
119a54f4eaeSMogball                                     ConversionPatternRewriter &rewriter) const {
120a54f4eaeSMogball   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
121a54f4eaeSMogball                                        adaptor.getOperands(),
122a54f4eaeSMogball                                        *getTypeConverter(), rewriter);
123a54f4eaeSMogball }
124a54f4eaeSMogball 
125a54f4eaeSMogball //===----------------------------------------------------------------------===//
126a54f4eaeSMogball // IndexCastOpLowering
127a54f4eaeSMogball //===----------------------------------------------------------------------===//
128a54f4eaeSMogball 
129a54f4eaeSMogball LogicalResult IndexCastOpLowering::matchAndRewrite(
130a54f4eaeSMogball     arith::IndexCastOp op, OpAdaptor adaptor,
131a54f4eaeSMogball     ConversionPatternRewriter &rewriter) const {
132a54f4eaeSMogball   auto targetType = typeConverter->convertType(op.getResult().getType());
133a54f4eaeSMogball   auto targetElementType =
134a54f4eaeSMogball       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
135a54f4eaeSMogball           .cast<IntegerType>();
136a54f4eaeSMogball   auto sourceElementType =
137*cfb72fd3SJacques Pienaar       getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
138a54f4eaeSMogball   unsigned targetBits = targetElementType.getWidth();
139a54f4eaeSMogball   unsigned sourceBits = sourceElementType.getWidth();
140a54f4eaeSMogball 
141a54f4eaeSMogball   if (targetBits == sourceBits)
142*cfb72fd3SJacques Pienaar     rewriter.replaceOp(op, adaptor.getIn());
143a54f4eaeSMogball   else if (targetBits < sourceBits)
144*cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
145a54f4eaeSMogball   else
146*cfb72fd3SJacques Pienaar     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
147a54f4eaeSMogball   return success();
148a54f4eaeSMogball }
149a54f4eaeSMogball 
150a54f4eaeSMogball //===----------------------------------------------------------------------===//
151a54f4eaeSMogball // CmpIOpLowering
152a54f4eaeSMogball //===----------------------------------------------------------------------===//
153a54f4eaeSMogball 
/// Maps an `arith` comparison predicate onto the corresponding LLVM dialect
/// predicate enum.
///
/// This relies on both enums assigning the same numeric value to equivalent
/// predicates, so a value-preserving static_cast is the whole conversion.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  LLVMPredType converted = static_cast<LLVMPredType>(pred);
  return converted;
}
160a54f4eaeSMogball 
161a54f4eaeSMogball LogicalResult
162a54f4eaeSMogball CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
163a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
164*cfb72fd3SJacques Pienaar   auto operandType = adaptor.getLhs().getType();
165a54f4eaeSMogball   auto resultType = op.getResult().getType();
166a54f4eaeSMogball 
167a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
168a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
169a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
170a54f4eaeSMogball         op, typeConverter->convertType(resultType),
171a54f4eaeSMogball         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
172*cfb72fd3SJacques Pienaar         adaptor.getLhs(), adaptor.getRhs());
173a54f4eaeSMogball     return success();
174a54f4eaeSMogball   }
175a54f4eaeSMogball 
176a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
177a54f4eaeSMogball   if (!vectorType)
178a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
179a54f4eaeSMogball 
180a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
181a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
183a54f4eaeSMogball         OpAdaptor adaptor(operands);
184a54f4eaeSMogball         return rewriter.create<LLVM::ICmpOp>(
185a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
186a54f4eaeSMogball             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
187*cfb72fd3SJacques Pienaar             adaptor.getLhs(), adaptor.getRhs());
188a54f4eaeSMogball       },
189a54f4eaeSMogball       rewriter);
190a54f4eaeSMogball 
191a54f4eaeSMogball   return success();
192a54f4eaeSMogball }
193a54f4eaeSMogball 
194a54f4eaeSMogball //===----------------------------------------------------------------------===//
195a54f4eaeSMogball // CmpFOpLowering
196a54f4eaeSMogball //===----------------------------------------------------------------------===//
197a54f4eaeSMogball 
198a54f4eaeSMogball LogicalResult
199a54f4eaeSMogball CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200a54f4eaeSMogball                                 ConversionPatternRewriter &rewriter) const {
201*cfb72fd3SJacques Pienaar   auto operandType = adaptor.getLhs().getType();
202a54f4eaeSMogball   auto resultType = op.getResult().getType();
203a54f4eaeSMogball 
204a54f4eaeSMogball   // Handle the scalar and 1D vector cases.
205a54f4eaeSMogball   if (!operandType.isa<LLVM::LLVMArrayType>()) {
206a54f4eaeSMogball     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207a54f4eaeSMogball         op, typeConverter->convertType(resultType),
208a54f4eaeSMogball         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209*cfb72fd3SJacques Pienaar         adaptor.getLhs(), adaptor.getRhs());
210a54f4eaeSMogball     return success();
211a54f4eaeSMogball   }
212a54f4eaeSMogball 
213a54f4eaeSMogball   auto vectorType = resultType.dyn_cast<VectorType>();
214a54f4eaeSMogball   if (!vectorType)
215a54f4eaeSMogball     return rewriter.notifyMatchFailure(op, "expected vector result type");
216a54f4eaeSMogball 
217a54f4eaeSMogball   return LLVM::detail::handleMultidimensionalVectors(
218a54f4eaeSMogball       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219a54f4eaeSMogball       [&](Type llvm1DVectorTy, ValueRange operands) {
220a54f4eaeSMogball         OpAdaptor adaptor(operands);
221a54f4eaeSMogball         return rewriter.create<LLVM::FCmpOp>(
222a54f4eaeSMogball             op.getLoc(), llvm1DVectorTy,
223a54f4eaeSMogball             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224*cfb72fd3SJacques Pienaar             adaptor.getLhs(), adaptor.getRhs());
225a54f4eaeSMogball       },
226a54f4eaeSMogball       rewriter);
227a54f4eaeSMogball }
228a54f4eaeSMogball 
229a54f4eaeSMogball //===----------------------------------------------------------------------===//
230a54f4eaeSMogball // Pass Definition
231a54f4eaeSMogball //===----------------------------------------------------------------------===//
232a54f4eaeSMogball 
233a54f4eaeSMogball namespace {
234a54f4eaeSMogball struct ConvertArithmeticToLLVMPass
235a54f4eaeSMogball     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236a54f4eaeSMogball   ConvertArithmeticToLLVMPass() = default;
237a54f4eaeSMogball 
238a54f4eaeSMogball   void runOnFunction() override {
239a54f4eaeSMogball     LLVMConversionTarget target(getContext());
240a54f4eaeSMogball     RewritePatternSet patterns(&getContext());
241a54f4eaeSMogball 
242a54f4eaeSMogball     LowerToLLVMOptions options(&getContext());
243a54f4eaeSMogball     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244a54f4eaeSMogball       options.overrideIndexBitwidth(indexBitwidth);
245a54f4eaeSMogball 
246a54f4eaeSMogball     LLVMTypeConverter converter(&getContext(), options);
247a54f4eaeSMogball     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248a54f4eaeSMogball                                                             patterns);
249a54f4eaeSMogball 
250a54f4eaeSMogball     if (failed(
251a54f4eaeSMogball             applyPartialConversion(getFunction(), target, std::move(patterns))))
252a54f4eaeSMogball       signalPassFailure();
253a54f4eaeSMogball   }
254a54f4eaeSMogball };
255a54f4eaeSMogball } // end anonymous namespace
256a54f4eaeSMogball 
257a54f4eaeSMogball //===----------------------------------------------------------------------===//
258a54f4eaeSMogball // Pattern Population
259a54f4eaeSMogball //===----------------------------------------------------------------------===//
260a54f4eaeSMogball 
261a54f4eaeSMogball void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262a54f4eaeSMogball     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263a54f4eaeSMogball   // clang-format off
264a54f4eaeSMogball   patterns.add<
265a54f4eaeSMogball     ConstantOpLowering,
266a54f4eaeSMogball     AddIOpLowering,
267a54f4eaeSMogball     SubIOpLowering,
268a54f4eaeSMogball     MulIOpLowering,
269a54f4eaeSMogball     DivUIOpLowering,
270a54f4eaeSMogball     DivSIOpLowering,
271a54f4eaeSMogball     RemUIOpLowering,
272a54f4eaeSMogball     RemSIOpLowering,
273a54f4eaeSMogball     AndIOpLowering,
274a54f4eaeSMogball     OrIOpLowering,
275a54f4eaeSMogball     XOrIOpLowering,
276a54f4eaeSMogball     ShLIOpLowering,
277a54f4eaeSMogball     ShRUIOpLowering,
278a54f4eaeSMogball     ShRSIOpLowering,
279a54f4eaeSMogball     NegFOpLowering,
280a54f4eaeSMogball     AddFOpLowering,
281a54f4eaeSMogball     SubFOpLowering,
282a54f4eaeSMogball     MulFOpLowering,
283a54f4eaeSMogball     DivFOpLowering,
284a54f4eaeSMogball     RemFOpLowering,
285a54f4eaeSMogball     ExtUIOpLowering,
286a54f4eaeSMogball     ExtSIOpLowering,
287a54f4eaeSMogball     ExtFOpLowering,
288a54f4eaeSMogball     TruncIOpLowering,
289a54f4eaeSMogball     TruncFOpLowering,
290a54f4eaeSMogball     UIToFPOpLowering,
291a54f4eaeSMogball     SIToFPOpLowering,
292a54f4eaeSMogball     FPToUIOpLowering,
293a54f4eaeSMogball     FPToSIOpLowering,
294a54f4eaeSMogball     IndexCastOpLowering,
295a54f4eaeSMogball     BitcastOpLowering,
296a54f4eaeSMogball     CmpIOpLowering,
297a54f4eaeSMogball     CmpFOpLowering
298a54f4eaeSMogball   >(converter);
299a54f4eaeSMogball   // clang-format on
300a54f4eaeSMogball }
301a54f4eaeSMogball 
302a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
303a54f4eaeSMogball   return std::make_unique<ConvertArithmeticToLLVMPass>();
304a54f4eaeSMogball }
305