1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Straightforward Op Lowerings
23 //===----------------------------------------------------------------------===//
24 
25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28 using DivUIOpLowering =
29     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30 using DivSIOpLowering =
31     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32 using RemUIOpLowering =
33     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34 using RemSIOpLowering =
35     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40 using ShRUIOpLowering =
41     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42 using ShRSIOpLowering =
43     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50 using ExtUIOpLowering =
51     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52 using ExtSIOpLowering =
53     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55 using TruncIOpLowering =
56     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57 using TruncFOpLowering =
58     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59 using UIToFPOpLowering =
60     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61 using SIToFPOpLowering =
62     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63 using FPToUIOpLowering =
64     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65 using FPToSIOpLowering =
66     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67 using BitcastOpLowering =
68     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69 
70 //===----------------------------------------------------------------------===//
71 // Op Lowering Patterns
72 //===----------------------------------------------------------------------===//
73 
74 /// Directly lower to LLVM op.
75 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
76   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
77 
78   LogicalResult
79   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
80                   ConversionPatternRewriter &rewriter) const override;
81 };
82 
83 /// The lowering of index_cast becomes an integer conversion since index
84 /// becomes an integer.  If the bit width of the source and target integer
85 /// types is the same, just erase the cast.  If the target type is wider,
86 /// sign-extend the value, otherwise truncate it.
87 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
88   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
89 
90   LogicalResult
91   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
92                   ConversionPatternRewriter &rewriter) const override;
93 };
94 
95 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
96   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
97 
98   LogicalResult
99   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
100                   ConversionPatternRewriter &rewriter) const override;
101 };
102 
103 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
104   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
105 
106   LogicalResult
107   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 } // end anonymous namespace
112 
113 //===----------------------------------------------------------------------===//
114 // ConstantOpLowering
115 //===----------------------------------------------------------------------===//
116 
117 LogicalResult
118 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
119                                     ConversionPatternRewriter &rewriter) const {
120   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
121                                        adaptor.getOperands(),
122                                        *getTypeConverter(), rewriter);
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // IndexCastOpLowering
127 //===----------------------------------------------------------------------===//
128 
129 LogicalResult IndexCastOpLowering::matchAndRewrite(
130     arith::IndexCastOp op, OpAdaptor adaptor,
131     ConversionPatternRewriter &rewriter) const {
132   auto targetType = typeConverter->convertType(op.getResult().getType());
133   auto targetElementType =
134       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
135           .cast<IntegerType>();
136   auto sourceElementType =
137       getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
138   unsigned targetBits = targetElementType.getWidth();
139   unsigned sourceBits = sourceElementType.getWidth();
140 
141   if (targetBits == sourceBits)
142     rewriter.replaceOp(op, adaptor.getIn());
143   else if (targetBits < sourceBits)
144     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
145   else
146     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
147   return success();
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // CmpIOpLowering
152 //===----------------------------------------------------------------------===//
153 
// Maps an arith comparison predicate onto the corresponding LLVM dialect
// predicate type. This relies on both enums using the same numerical values
// for matching predicates, so the conversion is a plain value-preserving cast.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto llvmPred = static_cast<LLVMPredType>(pred);
  return llvmPred;
}
160 
161 LogicalResult
162 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
163                                 ConversionPatternRewriter &rewriter) const {
164   auto operandType = adaptor.getLhs().getType();
165   auto resultType = op.getResult().getType();
166 
167   // Handle the scalar and 1D vector cases.
168   if (!operandType.isa<LLVM::LLVMArrayType>()) {
169     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
170         op, typeConverter->convertType(resultType),
171         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
172         adaptor.getLhs(), adaptor.getRhs());
173     return success();
174   }
175 
176   auto vectorType = resultType.dyn_cast<VectorType>();
177   if (!vectorType)
178     return rewriter.notifyMatchFailure(op, "expected vector result type");
179 
180   return LLVM::detail::handleMultidimensionalVectors(
181       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
182       [&](Type llvm1DVectorTy, ValueRange operands) {
183         OpAdaptor adaptor(operands);
184         return rewriter.create<LLVM::ICmpOp>(
185             op.getLoc(), llvm1DVectorTy,
186             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
187             adaptor.getLhs(), adaptor.getRhs());
188       },
189       rewriter);
190 
191   return success();
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // CmpFOpLowering
196 //===----------------------------------------------------------------------===//
197 
198 LogicalResult
199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200                                 ConversionPatternRewriter &rewriter) const {
201   auto operandType = adaptor.getLhs().getType();
202   auto resultType = op.getResult().getType();
203 
204   // Handle the scalar and 1D vector cases.
205   if (!operandType.isa<LLVM::LLVMArrayType>()) {
206     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207         op, typeConverter->convertType(resultType),
208         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209         adaptor.getLhs(), adaptor.getRhs());
210     return success();
211   }
212 
213   auto vectorType = resultType.dyn_cast<VectorType>();
214   if (!vectorType)
215     return rewriter.notifyMatchFailure(op, "expected vector result type");
216 
217   return LLVM::detail::handleMultidimensionalVectors(
218       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219       [&](Type llvm1DVectorTy, ValueRange operands) {
220         OpAdaptor adaptor(operands);
221         return rewriter.create<LLVM::FCmpOp>(
222             op.getLoc(), llvm1DVectorTy,
223             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224             adaptor.getLhs(), adaptor.getRhs());
225       },
226       rewriter);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Pass Definition
231 //===----------------------------------------------------------------------===//
232 
233 namespace {
234 struct ConvertArithmeticToLLVMPass
235     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236   ConvertArithmeticToLLVMPass() = default;
237 
238   void runOnFunction() override {
239     LLVMConversionTarget target(getContext());
240     RewritePatternSet patterns(&getContext());
241 
242     LowerToLLVMOptions options(&getContext());
243     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244       options.overrideIndexBitwidth(indexBitwidth);
245 
246     LLVMTypeConverter converter(&getContext(), options);
247     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248                                                             patterns);
249 
250     if (failed(
251             applyPartialConversion(getFunction(), target, std::move(patterns))))
252       signalPassFailure();
253   }
254 };
255 } // end anonymous namespace
256 
257 //===----------------------------------------------------------------------===//
258 // Pattern Population
259 //===----------------------------------------------------------------------===//
260 
261 void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263   // clang-format off
264   patterns.add<
265     ConstantOpLowering,
266     AddIOpLowering,
267     SubIOpLowering,
268     MulIOpLowering,
269     DivUIOpLowering,
270     DivSIOpLowering,
271     RemUIOpLowering,
272     RemSIOpLowering,
273     AndIOpLowering,
274     OrIOpLowering,
275     XOrIOpLowering,
276     ShLIOpLowering,
277     ShRUIOpLowering,
278     ShRSIOpLowering,
279     NegFOpLowering,
280     AddFOpLowering,
281     SubFOpLowering,
282     MulFOpLowering,
283     DivFOpLowering,
284     RemFOpLowering,
285     ExtUIOpLowering,
286     ExtSIOpLowering,
287     ExtFOpLowering,
288     TruncIOpLowering,
289     TruncFOpLowering,
290     UIToFPOpLowering,
291     SIToFPOpLowering,
292     FPToUIOpLowering,
293     FPToSIOpLowering,
294     IndexCastOpLowering,
295     BitcastOpLowering,
296     CmpIOpLowering,
297     CmpFOpLowering
298   >(converter);
299   // clang-format on
300 }
301 
302 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
303   return std::make_unique<ConvertArithmeticToLLVMPass>();
304 }
305