1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
11 #include "mlir/IR/TypeUtilities.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 
17 /// Expands CeilDivUIOp (n, m) into
18 ///  n == 0 ? 0 : ((n-1) / m) + 1
19 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
20   using OpRewritePattern::OpRewritePattern;
21   LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
22                                 PatternRewriter &rewriter) const final {
23     Location loc = op.getLoc();
24     Value a = op.lhs();
25     Value b = op.rhs();
26     Value zero = rewriter.create<arith::ConstantOp>(
27         loc, rewriter.getIntegerAttr(a.getType(), 0));
28     Value compare =
29         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
30     Value one = rewriter.create<arith::ConstantOp>(
31         loc, rewriter.getIntegerAttr(a.getType(), 1));
32     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
33     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
34     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
35     Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
36     rewriter.replaceOp(op, {res});
37     return success();
38   }
39 };
40 
41 /// Expands CeilDivSIOp (n, m) into
42 ///   1) x = (m > 0) ? -1 : 1
43 ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
44 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
45   using OpRewritePattern::OpRewritePattern;
46   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
47                                 PatternRewriter &rewriter) const final {
48     Location loc = op.getLoc();
49     auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
50     Type type = signedCeilDivIOp.getType();
51     Value a = signedCeilDivIOp.getLhs();
52     Value b = signedCeilDivIOp.getRhs();
53     Value plusOne = rewriter.create<arith::ConstantOp>(
54         loc, rewriter.getIntegerAttr(type, 1));
55     Value zero = rewriter.create<arith::ConstantOp>(
56         loc, rewriter.getIntegerAttr(type, 0));
57     Value minusOne = rewriter.create<arith::ConstantOp>(
58         loc, rewriter.getIntegerAttr(type, -1));
59     // Compute x = (b>0) ? -1 : 1.
60     Value compare =
61         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
62     Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne);
63     // Compute positive res: 1 + ((x+a)/b).
64     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
65     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
66     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
67     // Compute negative res: - ((-a)/b).
68     Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
69     Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
70     Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
71     // Result is (a*b>0) ? pos result : neg result.
72     // Note, we want to avoid using a*b because of possible overflow.
73     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
74     // not particuliarly care if a*b<0 is true or false when b is zero
75     // as this will result in an illegal divide. So `a*b<0` can be reformulated
76     // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
77     // We pick the first expression here.
78     Value aNeg =
79         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
80     Value aPos =
81         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
82     Value bNeg =
83         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
84     Value bPos =
85         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
86     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
87     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
88     Value compareRes =
89         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
90     Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
91     // Perform substitution and return success.
92     rewriter.replaceOp(op, {res});
93     return success();
94   }
95 };
96 
97 /// Expands FloorDivSIOp (n, m) into
98 ///   1)  x = (m<0) ? 1 : -1
99 ///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
100 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
101   using OpRewritePattern::OpRewritePattern;
102   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
103                                 PatternRewriter &rewriter) const final {
104     Location loc = op.getLoc();
105     arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
106     Type type = signedFloorDivIOp.getType();
107     Value a = signedFloorDivIOp.getLhs();
108     Value b = signedFloorDivIOp.getRhs();
109     Value plusOne = rewriter.create<arith::ConstantOp>(
110         loc, rewriter.getIntegerAttr(type, 1));
111     Value zero = rewriter.create<arith::ConstantOp>(
112         loc, rewriter.getIntegerAttr(type, 0));
113     Value minusOne = rewriter.create<arith::ConstantOp>(
114         loc, rewriter.getIntegerAttr(type, -1));
115     // Compute x = (b<0) ? 1 : -1.
116     Value compare =
117         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
118     Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne);
119     // Compute negative res: -1 - ((x-a)/b).
120     Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
121     Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
122     Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
123     // Compute positive res: a/b.
124     Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
125     // Result is (a*b<0) ? negative result : positive result.
126     // Note, we want to avoid using a*b because of possible overflow.
127     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
128     // not particuliarly care if a*b<0 is true or false when b is zero
129     // as this will result in an illegal divide. So `a*b<0` can be reformulated
130     // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
131     // We pick the first expression here.
132     Value aNeg =
133         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
134     Value aPos =
135         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
136     Value bNeg =
137         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
138     Value bPos =
139         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
140     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
141     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
142     Value compareRes =
143         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
144     Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
145     // Perform substitution and return success.
146     rewriter.replaceOp(op, {res});
147     return success();
148   }
149 };
150 
151 template <typename OpTy, arith::CmpFPredicate pred>
152 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
153 public:
154   using OpRewritePattern<OpTy>::OpRewritePattern;
155 
156   LogicalResult matchAndRewrite(OpTy op,
157                                 PatternRewriter &rewriter) const final {
158     Value lhs = op.getLhs();
159     Value rhs = op.getRhs();
160 
161     Location loc = op.getLoc();
162     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
163     Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
164 
165     auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
166     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
167                                                  lhs, rhs);
168 
169     Value nan = rewriter.create<arith::ConstantFloatOp>(
170         loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
171     if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
172       nan = rewriter.create<SplatOp>(loc, vectorType, nan);
173 
174     rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
175     return success();
176   }
177 };
178 
179 template <typename OpTy, arith::CmpIPredicate pred>
180 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
181 public:
182   using OpRewritePattern<OpTy>::OpRewritePattern;
183   LogicalResult matchAndRewrite(OpTy op,
184                                 PatternRewriter &rewriter) const final {
185     Value lhs = op.getLhs();
186     Value rhs = op.getRhs();
187 
188     Location loc = op.getLoc();
189     Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
190     rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
191     return success();
192   }
193 };
194 
195 struct ArithmeticExpandOpsPass
196     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
197   void runOnFunction() override {
198     RewritePatternSet patterns(&getContext());
199     ConversionTarget target(getContext());
200 
201     arith::populateArithmeticExpandOpsPatterns(patterns);
202 
203     target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
204     // clang-format off
205     target.addIllegalOp<
206       arith::CeilDivSIOp,
207       arith::CeilDivUIOp,
208       arith::FloorDivSIOp,
209       arith::MaxFOp,
210       arith::MaxSIOp,
211       arith::MaxUIOp,
212       arith::MinFOp,
213       arith::MinSIOp,
214       arith::MinUIOp
215     >();
216     // clang-format on
217     if (failed(
218             applyPartialConversion(getFunction(), target, std::move(patterns))))
219       signalPassFailure();
220   }
221 };
222 
223 } // end anonymous namespace
224 
225 void mlir::arith::populateArithmeticExpandOpsPatterns(
226     RewritePatternSet &patterns) {
227   // clang-format off
228   patterns.add<
229     CeilDivSIOpConverter,
230     CeilDivUIOpConverter,
231     FloorDivSIOpConverter,
232     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
233     MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
234     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
235     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
236     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
237     MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
238    >(patterns.getContext());
239   // clang-format on
240 }
241 
242 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
243   return std::make_unique<ArithmeticExpandOpsPass>();
244 }
245