1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 
20 /// Expands CeilDivUIOp (n, m) into
21 ///  n == 0 ? 0 : ((n-1) / m) + 1
22 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
23   using OpRewritePattern::OpRewritePattern;
24   LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
25                                 PatternRewriter &rewriter) const final {
26     Location loc = op.getLoc();
27     Value a = op.getLhs();
28     Value b = op.getRhs();
29     Value zero = rewriter.create<arith::ConstantOp>(
30         loc, rewriter.getIntegerAttr(a.getType(), 0));
31     Value compare =
32         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
33     Value one = rewriter.create<arith::ConstantOp>(
34         loc, rewriter.getIntegerAttr(a.getType(), 1));
35     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
36     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
37     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
38     Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
39     rewriter.replaceOp(op, {res});
40     return success();
41   }
42 };
43 
44 /// Expands CeilDivSIOp (n, m) into
45 ///   1) x = (m > 0) ? -1 : 1
46 ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
47 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
48   using OpRewritePattern::OpRewritePattern;
49   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
50                                 PatternRewriter &rewriter) const final {
51     Location loc = op.getLoc();
52     auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
53     Type type = signedCeilDivIOp.getType();
54     Value a = signedCeilDivIOp.getLhs();
55     Value b = signedCeilDivIOp.getRhs();
56     Value plusOne = rewriter.create<arith::ConstantOp>(
57         loc, rewriter.getIntegerAttr(type, 1));
58     Value zero = rewriter.create<arith::ConstantOp>(
59         loc, rewriter.getIntegerAttr(type, 0));
60     Value minusOne = rewriter.create<arith::ConstantOp>(
61         loc, rewriter.getIntegerAttr(type, -1));
62     // Compute x = (b>0) ? -1 : 1.
63     Value compare =
64         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
65     Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne);
66     // Compute positive res: 1 + ((x+a)/b).
67     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
68     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
69     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
70     // Compute negative res: - ((-a)/b).
71     Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
72     Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
73     Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
74     // Result is (a*b>0) ? pos result : neg result.
75     // Note, we want to avoid using a*b because of possible overflow.
76     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
77     // not particuliarly care if a*b<0 is true or false when b is zero
78     // as this will result in an illegal divide. So `a*b<0` can be reformulated
79     // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
80     // We pick the first expression here.
81     Value aNeg =
82         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
83     Value aPos =
84         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
85     Value bNeg =
86         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
87     Value bPos =
88         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
89     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
90     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
91     Value compareRes =
92         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
93     Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
94     // Perform substitution and return success.
95     rewriter.replaceOp(op, {res});
96     return success();
97   }
98 };
99 
100 /// Expands FloorDivSIOp (n, m) into
101 ///   1)  x = (m<0) ? 1 : -1
102 ///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
103 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
104   using OpRewritePattern::OpRewritePattern;
105   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
106                                 PatternRewriter &rewriter) const final {
107     Location loc = op.getLoc();
108     arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
109     Type type = signedFloorDivIOp.getType();
110     Value a = signedFloorDivIOp.getLhs();
111     Value b = signedFloorDivIOp.getRhs();
112     Value plusOne = rewriter.create<arith::ConstantOp>(
113         loc, rewriter.getIntegerAttr(type, 1));
114     Value zero = rewriter.create<arith::ConstantOp>(
115         loc, rewriter.getIntegerAttr(type, 0));
116     Value minusOne = rewriter.create<arith::ConstantOp>(
117         loc, rewriter.getIntegerAttr(type, -1));
118     // Compute x = (b<0) ? 1 : -1.
119     Value compare =
120         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
121     Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne);
122     // Compute negative res: -1 - ((x-a)/b).
123     Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
124     Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
125     Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
126     // Compute positive res: a/b.
127     Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
128     // Result is (a*b<0) ? negative result : positive result.
129     // Note, we want to avoid using a*b because of possible overflow.
130     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
131     // not particuliarly care if a*b<0 is true or false when b is zero
132     // as this will result in an illegal divide. So `a*b<0` can be reformulated
133     // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
134     // We pick the first expression here.
135     Value aNeg =
136         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137     Value aPos =
138         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
139     Value bNeg =
140         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
141     Value bPos =
142         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
143     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
144     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
145     Value compareRes =
146         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
147     Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
148     // Perform substitution and return success.
149     rewriter.replaceOp(op, {res});
150     return success();
151   }
152 };
153 
154 template <typename OpTy, arith::CmpFPredicate pred>
155 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
156 public:
157   using OpRewritePattern<OpTy>::OpRewritePattern;
158 
159   LogicalResult matchAndRewrite(OpTy op,
160                                 PatternRewriter &rewriter) const final {
161     Value lhs = op.getLhs();
162     Value rhs = op.getRhs();
163 
164     Location loc = op.getLoc();
165     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
166     Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
167 
168     auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
169     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
170                                                  lhs, rhs);
171 
172     Value nan = rewriter.create<arith::ConstantFloatOp>(
173         loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
174     if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
175       nan = rewriter.create<SplatOp>(loc, vectorType, nan);
176 
177     rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
178     return success();
179   }
180 };
181 
182 template <typename OpTy, arith::CmpIPredicate pred>
183 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
184 public:
185   using OpRewritePattern<OpTy>::OpRewritePattern;
186   LogicalResult matchAndRewrite(OpTy op,
187                                 PatternRewriter &rewriter) const final {
188     Value lhs = op.getLhs();
189     Value rhs = op.getRhs();
190 
191     Location loc = op.getLoc();
192     Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
193     rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
194     return success();
195   }
196 };
197 
198 struct ArithmeticExpandOpsPass
199     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
200   void runOnFunction() override {
201     RewritePatternSet patterns(&getContext());
202     ConversionTarget target(getContext());
203 
204     arith::populateArithmeticExpandOpsPatterns(patterns);
205 
206     target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
207     // clang-format off
208     target.addIllegalOp<
209       arith::CeilDivSIOp,
210       arith::CeilDivUIOp,
211       arith::FloorDivSIOp,
212       arith::MaxFOp,
213       arith::MaxSIOp,
214       arith::MaxUIOp,
215       arith::MinFOp,
216       arith::MinSIOp,
217       arith::MinUIOp
218     >();
219     // clang-format on
220     if (failed(
221             applyPartialConversion(getFunction(), target, std::move(patterns))))
222       signalPassFailure();
223   }
224 };
225 
226 } // namespace
227 
228 void mlir::arith::populateArithmeticExpandOpsPatterns(
229     RewritePatternSet &patterns) {
230   // clang-format off
231   patterns.add<
232     CeilDivSIOpConverter,
233     CeilDivUIOpConverter,
234     FloorDivSIOpConverter,
235     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
236     MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
237     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
238     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
239     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
240     MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
241    >(patterns.getContext());
242   // clang-format on
243 }
244 
245 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
246   return std::make_unique<ArithmeticExpandOpsPass>();
247 }
248