1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Straightforward Op Lowerings
23 //===----------------------------------------------------------------------===//
24 
// Each alias below instantiates `VectorConvertToLLVMPattern` with one `arith`
// op and the single LLVM dialect op that it is lowered to.
// NOTE(review): judging by its name, `VectorConvertToLLVMPattern` presumably
// also covers vector operands -- confirm against VectorPattern.h.

// Integer add/sub/mul plus unsigned and signed division and remainder.
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
using DivUIOpLowering =
    VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
using DivSIOpLowering =
    VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using RemUIOpLowering =
    VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using RemSIOpLowering =
    VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
// Bitwise logic and shifts (ShRUI -> LShr, ShRSI -> AShr).
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
using ShRUIOpLowering =
    VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
using ShRSIOpLowering =
    VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
// Floating-point arithmetic.
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
// Extensions and truncations (ExtUI -> ZExt, ExtSI -> SExt, ExtF -> FPExt,
// TruncI -> Trunc, TruncF -> FPTrunc).
using ExtUIOpLowering =
    VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using ExtSIOpLowering =
    VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
using TruncIOpLowering =
    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using TruncFOpLowering =
    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
// Conversions between integer and floating-point values.
using UIToFPOpLowering =
    VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
using SIToFPOpLowering =
    VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using FPToUIOpLowering =
    VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
using FPToSIOpLowering =
    VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
// Bitcast.
using BitcastOpLowering =
    VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69 
70 //===----------------------------------------------------------------------===//
71 // Op Lowering Patterns
72 //===----------------------------------------------------------------------===//
73 
74 /// Directly lower to LLVM op.
75 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
76   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
77 
78   LogicalResult
79   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
80                   ConversionPatternRewriter &rewriter) const override;
81 };
82 
83 /// The lowering of index_cast becomes an integer conversion since index
84 /// becomes an integer.  If the bit width of the source and target integer
85 /// types is the same, just erase the cast.  If the target type is wider,
86 /// sign-extend the value, otherwise truncate it.
87 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
88   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
89 
90   LogicalResult
91   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
92                   ConversionPatternRewriter &rewriter) const override;
93 };
94 
95 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
96   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
97 
98   LogicalResult
99   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
100                   ConversionPatternRewriter &rewriter) const override;
101 };
102 
103 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
104   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
105 
106   LogicalResult
107   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 } // namespace
112 
113 //===----------------------------------------------------------------------===//
114 // ConstantOpLowering
115 //===----------------------------------------------------------------------===//
116 
117 LogicalResult
118 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
119                                     ConversionPatternRewriter &rewriter) const {
120   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
121                                        adaptor.getOperands(),
122                                        *getTypeConverter(), rewriter);
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // IndexCastOpLowering
127 //===----------------------------------------------------------------------===//
128 
129 LogicalResult IndexCastOpLowering::matchAndRewrite(
130     arith::IndexCastOp op, OpAdaptor adaptor,
131     ConversionPatternRewriter &rewriter) const {
132   auto targetType = typeConverter->convertType(op.getResult().getType());
133   auto targetElementType =
134       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
135           .cast<IntegerType>();
136   auto sourceElementType =
137       getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
138   unsigned targetBits = targetElementType.getWidth();
139   unsigned sourceBits = sourceElementType.getWidth();
140 
141   if (targetBits == sourceBits)
142     rewriter.replaceOp(op, adaptor.getIn());
143   else if (targetBits < sourceBits)
144     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
145   else
146     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
147   return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // CmpIOpLowering
152 //===----------------------------------------------------------------------===//
153 
// Translate an arith comparison predicate into the corresponding LLVM dialect
// predicate type `LLVMPredType`. Both enums use the same numerical values, so
// a plain cast is all that is required.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto llvmPred = static_cast<LLVMPredType>(pred);
  return llvmPred;
}
160 
161 LogicalResult
162 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
163                                 ConversionPatternRewriter &rewriter) const {
164   auto operandType = adaptor.getLhs().getType();
165   auto resultType = op.getResult().getType();
166 
167   // Handle the scalar and 1D vector cases.
168   if (!operandType.isa<LLVM::LLVMArrayType>()) {
169     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
170         op, typeConverter->convertType(resultType),
171         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
172         adaptor.getLhs(), adaptor.getRhs());
173     return success();
174   }
175 
176   auto vectorType = resultType.dyn_cast<VectorType>();
177   if (!vectorType)
178     return rewriter.notifyMatchFailure(op, "expected vector result type");
179 
180   return LLVM::detail::handleMultidimensionalVectors(
181       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182       [&](Type llvm1DVectorTy, ValueRange operands) {
183         OpAdaptor adaptor(operands);
184         return rewriter.create<LLVM::ICmpOp>(
185             op.getLoc(), llvm1DVectorTy,
186             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
187             adaptor.getLhs(), adaptor.getRhs());
188       },
189       rewriter);
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // CmpFOpLowering
194 //===----------------------------------------------------------------------===//
195 
196 LogicalResult
197 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
198                                 ConversionPatternRewriter &rewriter) const {
199   auto operandType = adaptor.getLhs().getType();
200   auto resultType = op.getResult().getType();
201 
202   // Handle the scalar and 1D vector cases.
203   if (!operandType.isa<LLVM::LLVMArrayType>()) {
204     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
205         op, typeConverter->convertType(resultType),
206         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
207         adaptor.getLhs(), adaptor.getRhs());
208     return success();
209   }
210 
211   auto vectorType = resultType.dyn_cast<VectorType>();
212   if (!vectorType)
213     return rewriter.notifyMatchFailure(op, "expected vector result type");
214 
215   return LLVM::detail::handleMultidimensionalVectors(
216       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
217       [&](Type llvm1DVectorTy, ValueRange operands) {
218         OpAdaptor adaptor(operands);
219         return rewriter.create<LLVM::FCmpOp>(
220             op.getLoc(), llvm1DVectorTy,
221             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
222             adaptor.getLhs(), adaptor.getRhs());
223       },
224       rewriter);
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // Pass Definition
229 //===----------------------------------------------------------------------===//
230 
231 namespace {
232 struct ConvertArithmeticToLLVMPass
233     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
234   ConvertArithmeticToLLVMPass() = default;
235 
236   void runOnOperation() override {
237     LLVMConversionTarget target(getContext());
238     RewritePatternSet patterns(&getContext());
239 
240     LowerToLLVMOptions options(&getContext());
241     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
242       options.overrideIndexBitwidth(indexBitwidth);
243 
244     LLVMTypeConverter converter(&getContext(), options);
245     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
246                                                             patterns);
247 
248     if (failed(applyPartialConversion(getOperation(), target,
249                                       std::move(patterns))))
250       signalPassFailure();
251   }
252 };
253 } // namespace
254 
255 //===----------------------------------------------------------------------===//
256 // Pattern Population
257 //===----------------------------------------------------------------------===//
258 
259 void mlir::arith::populateArithmeticToLLVMConversionPatterns(
260     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
261   // clang-format off
262   patterns.add<
263     ConstantOpLowering,
264     AddIOpLowering,
265     SubIOpLowering,
266     MulIOpLowering,
267     DivUIOpLowering,
268     DivSIOpLowering,
269     RemUIOpLowering,
270     RemSIOpLowering,
271     AndIOpLowering,
272     OrIOpLowering,
273     XOrIOpLowering,
274     ShLIOpLowering,
275     ShRUIOpLowering,
276     ShRSIOpLowering,
277     NegFOpLowering,
278     AddFOpLowering,
279     SubFOpLowering,
280     MulFOpLowering,
281     DivFOpLowering,
282     RemFOpLowering,
283     ExtUIOpLowering,
284     ExtSIOpLowering,
285     ExtFOpLowering,
286     TruncIOpLowering,
287     TruncFOpLowering,
288     UIToFPOpLowering,
289     SIToFPOpLowering,
290     FPToUIOpLowering,
291     FPToSIOpLowering,
292     IndexCastOpLowering,
293     BitcastOpLowering,
294     CmpIOpLowering,
295     CmpFOpLowering
296   >(converter);
297   // clang-format on
298 }
299 
300 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
301   return std::make_unique<ConvertArithmeticToLLVMPass>();
302 }
303