1a54f4eaeSMogball //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball 
9a54f4eaeSMogball #include "PassDetail.h"
10f89bb3c0SAlexander Belyaev #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
129b1d90e8SAlexander Belyaev #include "mlir/IR/TypeUtilities.h"
13f89bb3c0SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
14a54f4eaeSMogball 
15a54f4eaeSMogball using namespace mlir;
16a54f4eaeSMogball 
178cb785caSMogball /// Create an integer or index constant.
createConst(Location loc,Type type,int value,PatternRewriter & rewriter)188cb785caSMogball static Value createConst(Location loc, Type type, int value,
198cb785caSMogball                          PatternRewriter &rewriter) {
208cb785caSMogball   return rewriter.create<arith::ConstantOp>(
218cb785caSMogball       loc, rewriter.getIntegerAttr(type, value));
228cb785caSMogball }
238cb785caSMogball 
24a54f4eaeSMogball namespace {
25a54f4eaeSMogball 
268165eaa8Slipracer /// Expands CeilDivUIOp (n, m) into
278165eaa8Slipracer ///  n == 0 ? 0 : ((n-1) / m) + 1
288165eaa8Slipracer struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
298165eaa8Slipracer   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivUIOpConverter308165eaa8Slipracer   LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
318165eaa8Slipracer                                 PatternRewriter &rewriter) const final {
328165eaa8Slipracer     Location loc = op.getLoc();
3362fea88bSJacques Pienaar     Value a = op.getLhs();
3462fea88bSJacques Pienaar     Value b = op.getRhs();
358cb785caSMogball     Value zero = createConst(loc, a.getType(), 0, rewriter);
368165eaa8Slipracer     Value compare =
378165eaa8Slipracer         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
388cb785caSMogball     Value one = createConst(loc, a.getType(), 1, rewriter);
398165eaa8Slipracer     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
408165eaa8Slipracer     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
418165eaa8Slipracer     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
42dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
438165eaa8Slipracer     return success();
448165eaa8Slipracer   }
458165eaa8Slipracer };
468165eaa8Slipracer 
47a54f4eaeSMogball /// Expands CeilDivSIOp (n, m) into
48a54f4eaeSMogball ///   1) x = (m > 0) ? -1 : 1
49a54f4eaeSMogball ///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
50a54f4eaeSMogball struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
51a54f4eaeSMogball   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivSIOpConverter52a54f4eaeSMogball   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
53a54f4eaeSMogball                                 PatternRewriter &rewriter) const final {
54a54f4eaeSMogball     Location loc = op.getLoc();
558cb785caSMogball     Type type = op.getType();
568cb785caSMogball     Value a = op.getLhs();
578cb785caSMogball     Value b = op.getRhs();
588cb785caSMogball     Value plusOne = createConst(loc, type, 1, rewriter);
598cb785caSMogball     Value zero = createConst(loc, type, 0, rewriter);
608cb785caSMogball     Value minusOne = createConst(loc, type, -1, rewriter);
61a54f4eaeSMogball     // Compute x = (b>0) ? -1 : 1.
62a54f4eaeSMogball     Value compare =
63a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
64dec8af70SRiver Riddle     Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
65a54f4eaeSMogball     // Compute positive res: 1 + ((x+a)/b).
66a54f4eaeSMogball     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
67a54f4eaeSMogball     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
68a54f4eaeSMogball     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
69a54f4eaeSMogball     // Compute negative res: - ((-a)/b).
70a54f4eaeSMogball     Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
71a54f4eaeSMogball     Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
72a54f4eaeSMogball     Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
73a54f4eaeSMogball     // Result is (a*b>0) ? pos result : neg result.
74a54f4eaeSMogball     // Note, we want to avoid using a*b because of possible overflow.
75a54f4eaeSMogball     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
76a54f4eaeSMogball     // not particuliarly care if a*b<0 is true or false when b is zero
77a54f4eaeSMogball     // as this will result in an illegal divide. So `a*b<0` can be reformulated
78a54f4eaeSMogball     // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
79a54f4eaeSMogball     // We pick the first expression here.
80a54f4eaeSMogball     Value aNeg =
81a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
82a54f4eaeSMogball     Value aPos =
83a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
84a54f4eaeSMogball     Value bNeg =
85a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
86a54f4eaeSMogball     Value bPos =
87a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
88a54f4eaeSMogball     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
89a54f4eaeSMogball     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
90a54f4eaeSMogball     Value compareRes =
91a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
92a54f4eaeSMogball     // Perform substitution and return success.
93dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
94dec8af70SRiver Riddle                                                  negRes);
95a54f4eaeSMogball     return success();
96a54f4eaeSMogball   }
97a54f4eaeSMogball };
98a54f4eaeSMogball 
99a54f4eaeSMogball /// Expands FloorDivSIOp (n, m) into
100a54f4eaeSMogball ///   1)  x = (m<0) ? 1 : -1
101a54f4eaeSMogball ///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
102a54f4eaeSMogball struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
103a54f4eaeSMogball   using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::FloorDivSIOpConverter104a54f4eaeSMogball   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
105a54f4eaeSMogball                                 PatternRewriter &rewriter) const final {
106a54f4eaeSMogball     Location loc = op.getLoc();
1078cb785caSMogball     Type type = op.getType();
1088cb785caSMogball     Value a = op.getLhs();
1098cb785caSMogball     Value b = op.getRhs();
1108cb785caSMogball     Value plusOne = createConst(loc, type, 1, rewriter);
1118cb785caSMogball     Value zero = createConst(loc, type, 0, rewriter);
1128cb785caSMogball     Value minusOne = createConst(loc, type, -1, rewriter);
113a54f4eaeSMogball     // Compute x = (b<0) ? 1 : -1.
114a54f4eaeSMogball     Value compare =
115a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
116dec8af70SRiver Riddle     Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
117a54f4eaeSMogball     // Compute negative res: -1 - ((x-a)/b).
118a54f4eaeSMogball     Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
119a54f4eaeSMogball     Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
120a54f4eaeSMogball     Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
121a54f4eaeSMogball     // Compute positive res: a/b.
122a54f4eaeSMogball     Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
123a54f4eaeSMogball     // Result is (a*b<0) ? negative result : positive result.
124a54f4eaeSMogball     // Note, we want to avoid using a*b because of possible overflow.
125a54f4eaeSMogball     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
126a54f4eaeSMogball     // not particuliarly care if a*b<0 is true or false when b is zero
127a54f4eaeSMogball     // as this will result in an illegal divide. So `a*b<0` can be reformulated
128a54f4eaeSMogball     // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
129a54f4eaeSMogball     // We pick the first expression here.
130a54f4eaeSMogball     Value aNeg =
131a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
132a54f4eaeSMogball     Value aPos =
133a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
134a54f4eaeSMogball     Value bNeg =
135a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
136a54f4eaeSMogball     Value bPos =
137a54f4eaeSMogball         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
138a54f4eaeSMogball     Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
139a54f4eaeSMogball     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
140a54f4eaeSMogball     Value compareRes =
141a54f4eaeSMogball         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
142a54f4eaeSMogball     // Perform substitution and return success.
143dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
144dec8af70SRiver Riddle                                                  posRes);
145a54f4eaeSMogball     return success();
146a54f4eaeSMogball   }
147a54f4eaeSMogball };
148a54f4eaeSMogball 
1499b1d90e8SAlexander Belyaev template <typename OpTy, arith::CmpFPredicate pred>
1509b1d90e8SAlexander Belyaev struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
1519b1d90e8SAlexander Belyaev public:
1529b1d90e8SAlexander Belyaev   using OpRewritePattern<OpTy>::OpRewritePattern;
1539b1d90e8SAlexander Belyaev 
matchAndRewrite__anond9e264a70111::MaxMinFOpConverter1549b1d90e8SAlexander Belyaev   LogicalResult matchAndRewrite(OpTy op,
1559b1d90e8SAlexander Belyaev                                 PatternRewriter &rewriter) const final {
1569b1d90e8SAlexander Belyaev     Value lhs = op.getLhs();
1579b1d90e8SAlexander Belyaev     Value rhs = op.getRhs();
1589b1d90e8SAlexander Belyaev 
1599b1d90e8SAlexander Belyaev     Location loc = op.getLoc();
160be1aeb81SChristian Sigg     // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
161be1aeb81SChristian Sigg     static_assert(pred == arith::CmpFPredicate::UGT ||
16244bb5cd8SKazu Hirata                       pred == arith::CmpFPredicate::ULT,
16344bb5cd8SKazu Hirata                   "pred must be either UGT or ULT");
1649b1d90e8SAlexander Belyaev     Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
165dec8af70SRiver Riddle     Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
1669b1d90e8SAlexander Belyaev 
167be1aeb81SChristian Sigg     // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
1689b1d90e8SAlexander Belyaev     Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
169be1aeb81SChristian Sigg                                                  rhs, rhs);
170dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
1719b1d90e8SAlexander Belyaev     return success();
1729b1d90e8SAlexander Belyaev   }
1739b1d90e8SAlexander Belyaev };
1749b1d90e8SAlexander Belyaev 
1759b1d90e8SAlexander Belyaev template <typename OpTy, arith::CmpIPredicate pred>
1769b1d90e8SAlexander Belyaev struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
1779b1d90e8SAlexander Belyaev public:
1789b1d90e8SAlexander Belyaev   using OpRewritePattern<OpTy>::OpRewritePattern;
matchAndRewrite__anond9e264a70111::MaxMinIOpConverter1799b1d90e8SAlexander Belyaev   LogicalResult matchAndRewrite(OpTy op,
1809b1d90e8SAlexander Belyaev                                 PatternRewriter &rewriter) const final {
1819b1d90e8SAlexander Belyaev     Value lhs = op.getLhs();
1829b1d90e8SAlexander Belyaev     Value rhs = op.getRhs();
1839b1d90e8SAlexander Belyaev 
1849b1d90e8SAlexander Belyaev     Location loc = op.getLoc();
1859b1d90e8SAlexander Belyaev     Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
186dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
1879b1d90e8SAlexander Belyaev     return success();
1889b1d90e8SAlexander Belyaev   }
1899b1d90e8SAlexander Belyaev };
1909b1d90e8SAlexander Belyaev 
191a54f4eaeSMogball struct ArithmeticExpandOpsPass
192a54f4eaeSMogball     : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
runOnOperation__anond9e264a70111::ArithmeticExpandOpsPass19341574554SRiver Riddle   void runOnOperation() override {
194a54f4eaeSMogball     RewritePatternSet patterns(&getContext());
195a54f4eaeSMogball     ConversionTarget target(getContext());
196a54f4eaeSMogball 
197a54f4eaeSMogball     arith::populateArithmeticExpandOpsPatterns(patterns);
198a54f4eaeSMogball 
1991f971e23SRiver Riddle     target.addLegalDialect<arith::ArithmeticDialect>();
2009b1d90e8SAlexander Belyaev     // clang-format off
2019b1d90e8SAlexander Belyaev     target.addIllegalOp<
2029b1d90e8SAlexander Belyaev       arith::CeilDivSIOp,
2039b1d90e8SAlexander Belyaev       arith::CeilDivUIOp,
2049b1d90e8SAlexander Belyaev       arith::FloorDivSIOp,
2059b1d90e8SAlexander Belyaev       arith::MaxFOp,
2069b1d90e8SAlexander Belyaev       arith::MaxSIOp,
2079b1d90e8SAlexander Belyaev       arith::MaxUIOp,
2089b1d90e8SAlexander Belyaev       arith::MinFOp,
2099b1d90e8SAlexander Belyaev       arith::MinSIOp,
210*894641e9SChristopher Bate       arith::MinUIOp
2119b1d90e8SAlexander Belyaev     >();
2129b1d90e8SAlexander Belyaev     // clang-format on
21341574554SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
21441574554SRiver Riddle                                       std::move(patterns))))
215a54f4eaeSMogball       signalPassFailure();
216a54f4eaeSMogball   }
217a54f4eaeSMogball };
218a54f4eaeSMogball 
219be0a7e9fSMehdi Amini } // namespace
220a54f4eaeSMogball 
populateArithmeticExpandOpsPatterns(RewritePatternSet & patterns)221a54f4eaeSMogball void mlir::arith::populateArithmeticExpandOpsPatterns(
222a54f4eaeSMogball     RewritePatternSet &patterns) {
2239b1d90e8SAlexander Belyaev   // clang-format off
2249b1d90e8SAlexander Belyaev   patterns.add<
2259b1d90e8SAlexander Belyaev     CeilDivSIOpConverter,
2269b1d90e8SAlexander Belyaev     CeilDivUIOpConverter,
2279b1d90e8SAlexander Belyaev     FloorDivSIOpConverter,
228be1aeb81SChristian Sigg     MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
229be1aeb81SChristian Sigg     MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
2309b1d90e8SAlexander Belyaev     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
2319b1d90e8SAlexander Belyaev     MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
2329b1d90e8SAlexander Belyaev     MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
233*894641e9SChristopher Bate     MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
2349b1d90e8SAlexander Belyaev    >(patterns.getContext());
2359b1d90e8SAlexander Belyaev   // clang-format on
236a54f4eaeSMogball }
237a54f4eaeSMogball 
createArithmeticExpandOpsPass()238a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
239a54f4eaeSMogball   return std::make_unique<ArithmeticExpandOpsPass>();
240a54f4eaeSMogball }
241