1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
12 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 /// Create an integer or index constant.
19 static Value createConst(Location loc, Type type, int value,
20                          PatternRewriter &rewriter) {
21   return rewriter.create<arith::ConstantOp>(
22       loc, rewriter.getIntegerAttr(type, value));
23 }
24 
25 namespace {
26 
27 /// Expands CeilDivUIOp (n, m) into
28 ///  n == 0 ? 0 : ((n-1) / m) + 1
29 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
30   using OpRewritePattern::OpRewritePattern;
31   LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
32                                 PatternRewriter &rewriter) const final {
33     Location loc = op.getLoc();
34     Value a = op.getLhs();
35     Value b = op.getRhs();
36     Value zero = createConst(loc, a.getType(), 0, rewriter);
37     Value compare =
38         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
39     Value one = createConst(loc, a.getType(), 1, rewriter);
40     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
41     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
42     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
43     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
44     return success();
45   }
46 };
47 
48 /// Expands CeilDivSIOp (n, m) into
49 ///   1) x = (m > 0) ? -1 : 1
50 ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
51 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
52   using OpRewritePattern::OpRewritePattern;
53   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
54                                 PatternRewriter &rewriter) const final {
55     Location loc = op.getLoc();
56     Type type = op.getType();
57     Value a = op.getLhs();
58     Value b = op.getRhs();
59     Value plusOne = createConst(loc, type, 1, rewriter);
60     Value zero = createConst(loc, type, 0, rewriter);
61     Value minusOne = createConst(loc, type, -1, rewriter);
62     // Compute x = (b>0) ? -1 : 1.
63     Value compare =
64         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
65     Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
66     // Compute positive res: 1 + ((x+a)/b).
67     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
68     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
69     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
70     // Compute negative res: - ((-a)/b).
71     Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
72     Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
73     Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
74     // Result is (a*b>0) ? pos result : neg result.
75     // Note, we want to avoid using a*b because of possible overflow.
76     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
77     // not particuliarly care if a*b<0 is true or false when b is zero
78     // as this will result in an illegal divide. So `a*b<0` can be reformulated
79     // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
80     // We pick the first expression here.
81     Value aNeg =
82         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
83     Value aPos =
84         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
85     Value bNeg =
86         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
87     Value bPos =
88         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
89     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
90     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
91     Value compareRes =
92         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
93     // Perform substitution and return success.
94     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
95                                                  negRes);
96     return success();
97   }
98 };
99 
100 /// Expands FloorDivSIOp (n, m) into
101 ///   1)  x = (m<0) ? 1 : -1
102 ///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
103 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
104   using OpRewritePattern::OpRewritePattern;
105   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
106                                 PatternRewriter &rewriter) const final {
107     Location loc = op.getLoc();
108     Type type = op.getType();
109     Value a = op.getLhs();
110     Value b = op.getRhs();
111     Value plusOne = createConst(loc, type, 1, rewriter);
112     Value zero = createConst(loc, type, 0, rewriter);
113     Value minusOne = createConst(loc, type, -1, rewriter);
114     // Compute x = (b<0) ? 1 : -1.
115     Value compare =
116         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
117     Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
118     // Compute negative res: -1 - ((x-a)/b).
119     Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
120     Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
121     Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
122     // Compute positive res: a/b.
123     Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
124     // Result is (a*b<0) ? negative result : positive result.
125     // Note, we want to avoid using a*b because of possible overflow.
126     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
127     // not particuliarly care if a*b<0 is true or false when b is zero
128     // as this will result in an illegal divide. So `a*b<0` can be reformulated
129     // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
130     // We pick the first expression here.
131     Value aNeg =
132         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
133     Value aPos =
134         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
135     Value bNeg =
136         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
137     Value bPos =
138         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
139     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
140     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
141     Value compareRes =
142         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
143     // Perform substitution and return success.
144     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
145                                                  posRes);
146     return success();
147   }
148 };
149 
150 template <typename OpTy, arith::CmpFPredicate pred>
151 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
152 public:
153   using OpRewritePattern<OpTy>::OpRewritePattern;
154 
155   LogicalResult matchAndRewrite(OpTy op,
156                                 PatternRewriter &rewriter) const final {
157     Value lhs = op.getLhs();
158     Value rhs = op.getRhs();
159 
160     Location loc = op.getLoc();
161     // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
162     static_assert(pred == arith::CmpFPredicate::UGT ||
163                       pred == arith::CmpFPredicate::ULT,
164                   "pred must be either UGT or ULT");
165     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
166     Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
167 
168     // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
169     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
170                                                  rhs, rhs);
171     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
172     return success();
173   }
174 };
175 
176 template <typename OpTy, arith::CmpIPredicate pred>
177 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
178 public:
179   using OpRewritePattern<OpTy>::OpRewritePattern;
180   LogicalResult matchAndRewrite(OpTy op,
181                                 PatternRewriter &rewriter) const final {
182     Value lhs = op.getLhs();
183     Value rhs = op.getRhs();
184 
185     Location loc = op.getLoc();
186     Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
187     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
188     return success();
189   }
190 };
191 
192 /// Lowers `arith.delinearize_index` into a sequence of division and remainder
193 /// operations.
194 struct LowerDelinearizeIndexOps
195     : public OpRewritePattern<arith::DelinearizeIndexOp> {
196   using OpRewritePattern<arith::DelinearizeIndexOp>::OpRewritePattern;
197   LogicalResult matchAndRewrite(arith::DelinearizeIndexOp op,
198                                 PatternRewriter &rewriter) const override {
199     FailureOr<SmallVector<Value>> multiIndex =
200         delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
201                          llvm::to_vector(op.getBasis()));
202     if (failed(multiIndex))
203       return failure();
204     rewriter.replaceOp(op, *multiIndex);
205     return success();
206   }
207 };
208 
209 struct ArithmeticExpandOpsPass
210     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
211   void runOnOperation() override {
212     RewritePatternSet patterns(&getContext());
213     ConversionTarget target(getContext());
214 
215     arith::populateArithmeticExpandOpsPatterns(patterns);
216 
217     target.addLegalDialect<arith::ArithmeticDialect>();
218     // clang-format off
219     target.addIllegalOp<
220       arith::CeilDivSIOp,
221       arith::CeilDivUIOp,
222       arith::FloorDivSIOp,
223       arith::MaxFOp,
224       arith::MaxSIOp,
225       arith::MaxUIOp,
226       arith::MinFOp,
227       arith::MinSIOp,
228       arith::MinUIOp,
229       arith::DelinearizeIndexOp
230     >();
231     // clang-format on
232     if (failed(applyPartialConversion(getOperation(), target,
233                                       std::move(patterns))))
234       signalPassFailure();
235   }
236 };
237 
238 } // namespace
239 
240 void mlir::arith::populateArithmeticExpandOpsPatterns(
241     RewritePatternSet &patterns) {
242   // clang-format off
243   patterns.add<
244     CeilDivSIOpConverter,
245     CeilDivUIOpConverter,
246     FloorDivSIOpConverter,
247     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
248     MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
249     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
250     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
251     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
252     MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
253     LowerDelinearizeIndexOps
254    >(patterns.getContext());
255   // clang-format on
256 }
257 
258 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
259   return std::make_unique<ArithmeticExpandOpsPass>();
260 }
261