1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "../PassDetail.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Straightforward Op Lowerings
23 //===----------------------------------------------------------------------===//
24 
25 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
26 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
27 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
28 using DivUIOpLowering =
29     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
30 using DivSIOpLowering =
31     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
32 using RemUIOpLowering =
33     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
34 using RemSIOpLowering =
35     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
36 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
37 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
38 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
39 using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
40 using ShRUIOpLowering =
41     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
42 using ShRSIOpLowering =
43     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
44 using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
45 using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
46 using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
47 using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
48 using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
49 using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
50 using ExtUIOpLowering =
51     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
52 using ExtSIOpLowering =
53     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
54 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
55 using TruncIOpLowering =
56     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
57 using TruncFOpLowering =
58     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
59 using UIToFPOpLowering =
60     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
61 using SIToFPOpLowering =
62     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
63 using FPToUIOpLowering =
64     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
65 using FPToSIOpLowering =
66     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
67 using BitcastOpLowering =
68     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
69 using SelectOpLowering =
70     VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
71 
72 //===----------------------------------------------------------------------===//
73 // Op Lowering Patterns
74 //===----------------------------------------------------------------------===//
75 
76 /// Directly lower to LLVM op.
77 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
78   using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;
79 
80   LogicalResult
81   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
82                   ConversionPatternRewriter &rewriter) const override;
83 };
84 
85 /// The lowering of index_cast becomes an integer conversion since index
86 /// becomes an integer.  If the bit width of the source and target integer
87 /// types is the same, just erase the cast.  If the target type is wider,
88 /// sign-extend the value, otherwise truncate it.
89 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
90   using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;
91 
92   LogicalResult
93   matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
94                   ConversionPatternRewriter &rewriter) const override;
95 };
96 
97 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
98   using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;
99 
100   LogicalResult
101   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
102                   ConversionPatternRewriter &rewriter) const override;
103 };
104 
105 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
106   using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;
107 
108   LogicalResult
109   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
110                   ConversionPatternRewriter &rewriter) const override;
111 };
112 
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // ConstantOpLowering
117 //===----------------------------------------------------------------------===//
118 
119 LogicalResult
matchAndRewrite(arith::ConstantOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const120 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
121                                     ConversionPatternRewriter &rewriter) const {
122   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
123                                        adaptor.getOperands(),
124                                        *getTypeConverter(), rewriter);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // IndexCastOpLowering
129 //===----------------------------------------------------------------------===//
130 
matchAndRewrite(arith::IndexCastOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const131 LogicalResult IndexCastOpLowering::matchAndRewrite(
132     arith::IndexCastOp op, OpAdaptor adaptor,
133     ConversionPatternRewriter &rewriter) const {
134   auto targetType = typeConverter->convertType(op.getResult().getType());
135   auto targetElementType =
136       typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
137           .cast<IntegerType>();
138   auto sourceElementType =
139       getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
140   unsigned targetBits = targetElementType.getWidth();
141   unsigned sourceBits = sourceElementType.getWidth();
142 
143   if (targetBits == sourceBits)
144     rewriter.replaceOp(op, adaptor.getIn());
145   else if (targetBits < sourceBits)
146     rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
147   else
148     rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
149   return success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // CmpIOpLowering
154 //===----------------------------------------------------------------------===//
155 
// Maps an arith comparison predicate (arith.cmpi / arith.cmpf) onto the
// matching LLVM dialect predicate enum. Both enumerations use the same numeric
// value for each predicate, so a value-preserving cast is sufficient.
template <typename TargetPredT, typename SourcePredT>
static TargetPredT convertCmpPredicate(SourcePredT sourcePred) {
  const TargetPredT mapped = static_cast<TargetPredT>(sourcePred);
  return mapped;
}
162 
163 LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const164 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165                                 ConversionPatternRewriter &rewriter) const {
166   auto operandType = adaptor.getLhs().getType();
167   auto resultType = op.getResult().getType();
168 
169   // Handle the scalar and 1D vector cases.
170   if (!operandType.isa<LLVM::LLVMArrayType>()) {
171     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
172         op, typeConverter->convertType(resultType),
173         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
174         adaptor.getLhs(), adaptor.getRhs());
175     return success();
176   }
177 
178   auto vectorType = resultType.dyn_cast<VectorType>();
179   if (!vectorType)
180     return rewriter.notifyMatchFailure(op, "expected vector result type");
181 
182   return LLVM::detail::handleMultidimensionalVectors(
183       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
184       [&](Type llvm1DVectorTy, ValueRange operands) {
185         OpAdaptor adaptor(operands);
186         return rewriter.create<LLVM::ICmpOp>(
187             op.getLoc(), llvm1DVectorTy,
188             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
189             adaptor.getLhs(), adaptor.getRhs());
190       },
191       rewriter);
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // CmpFOpLowering
196 //===----------------------------------------------------------------------===//
197 
198 LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200                                 ConversionPatternRewriter &rewriter) const {
201   auto operandType = adaptor.getLhs().getType();
202   auto resultType = op.getResult().getType();
203 
204   // Handle the scalar and 1D vector cases.
205   if (!operandType.isa<LLVM::LLVMArrayType>()) {
206     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207         op, typeConverter->convertType(resultType),
208         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209         adaptor.getLhs(), adaptor.getRhs());
210     return success();
211   }
212 
213   auto vectorType = resultType.dyn_cast<VectorType>();
214   if (!vectorType)
215     return rewriter.notifyMatchFailure(op, "expected vector result type");
216 
217   return LLVM::detail::handleMultidimensionalVectors(
218       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219       [&](Type llvm1DVectorTy, ValueRange operands) {
220         OpAdaptor adaptor(operands);
221         return rewriter.create<LLVM::FCmpOp>(
222             op.getLoc(), llvm1DVectorTy,
223             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224             adaptor.getLhs(), adaptor.getRhs());
225       },
226       rewriter);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Pass Definition
231 //===----------------------------------------------------------------------===//
232 
233 namespace {
234 struct ConvertArithmeticToLLVMPass
235     : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236   ConvertArithmeticToLLVMPass() = default;
237 
runOnOperation__anonb4369bed0411::ConvertArithmeticToLLVMPass238   void runOnOperation() override {
239     LLVMConversionTarget target(getContext());
240     RewritePatternSet patterns(&getContext());
241 
242     LowerToLLVMOptions options(&getContext());
243     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244       options.overrideIndexBitwidth(indexBitwidth);
245 
246     LLVMTypeConverter converter(&getContext(), options);
247     mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
248                                                             patterns);
249 
250     if (failed(applyPartialConversion(getOperation(), target,
251                                       std::move(patterns))))
252       signalPassFailure();
253   }
254 };
255 } // namespace
256 
257 //===----------------------------------------------------------------------===//
258 // Pattern Population
259 //===----------------------------------------------------------------------===//
260 
populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)261 void mlir::arith::populateArithmeticToLLVMConversionPatterns(
262     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263   // clang-format off
264   patterns.add<
265     ConstantOpLowering,
266     AddIOpLowering,
267     SubIOpLowering,
268     MulIOpLowering,
269     DivUIOpLowering,
270     DivSIOpLowering,
271     RemUIOpLowering,
272     RemSIOpLowering,
273     AndIOpLowering,
274     OrIOpLowering,
275     XOrIOpLowering,
276     ShLIOpLowering,
277     ShRUIOpLowering,
278     ShRSIOpLowering,
279     NegFOpLowering,
280     AddFOpLowering,
281     SubFOpLowering,
282     MulFOpLowering,
283     DivFOpLowering,
284     RemFOpLowering,
285     ExtUIOpLowering,
286     ExtSIOpLowering,
287     ExtFOpLowering,
288     TruncIOpLowering,
289     TruncFOpLowering,
290     UIToFPOpLowering,
291     SIToFPOpLowering,
292     FPToUIOpLowering,
293     FPToSIOpLowering,
294     IndexCastOpLowering,
295     BitcastOpLowering,
296     CmpIOpLowering,
297     CmpFOpLowering,
298     SelectOpLowering
299   >(converter);
300   // clang-format on
301 }
302 
createConvertArithmeticToLLVMPass()303 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
304   return std::make_unique<ConvertArithmeticToLLVMPass>();
305 }
306