1a54f4eaeSMogball //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
14a54f4eaeSMogball
15a54f4eaeSMogball using namespace mlir;
16a54f4eaeSMogball
178cb785caSMogball /// Create an integer or index constant.
createConst(Location loc,Type type,int value,PatternRewriter & rewriter)188cb785caSMogball static Value createConst(Location loc, Type type, int value,
198cb785caSMogball PatternRewriter &rewriter) {
208cb785caSMogball return rewriter.create<arith::ConstantOp>(
218cb785caSMogball loc, rewriter.getIntegerAttr(type, value));
228cb785caSMogball }
238cb785caSMogball
24a54f4eaeSMogball namespace {
25a54f4eaeSMogball
268165eaa8Slipracer /// Expands CeilDivUIOp (n, m) into
278165eaa8Slipracer /// n == 0 ? 0 : ((n-1) / m) + 1
288165eaa8Slipracer struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
298165eaa8Slipracer using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivUIOpConverter308165eaa8Slipracer LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
318165eaa8Slipracer PatternRewriter &rewriter) const final {
328165eaa8Slipracer Location loc = op.getLoc();
3362fea88bSJacques Pienaar Value a = op.getLhs();
3462fea88bSJacques Pienaar Value b = op.getRhs();
358cb785caSMogball Value zero = createConst(loc, a.getType(), 0, rewriter);
368165eaa8Slipracer Value compare =
378165eaa8Slipracer rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
388cb785caSMogball Value one = createConst(loc, a.getType(), 1, rewriter);
398165eaa8Slipracer Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
408165eaa8Slipracer Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
418165eaa8Slipracer Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
42dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
438165eaa8Slipracer return success();
448165eaa8Slipracer }
458165eaa8Slipracer };
468165eaa8Slipracer
47a54f4eaeSMogball /// Expands CeilDivSIOp (n, m) into
48a54f4eaeSMogball /// 1) x = (m > 0) ? -1 : 1
49a54f4eaeSMogball /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
50a54f4eaeSMogball struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
51a54f4eaeSMogball using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivSIOpConverter52a54f4eaeSMogball LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
53a54f4eaeSMogball PatternRewriter &rewriter) const final {
54a54f4eaeSMogball Location loc = op.getLoc();
558cb785caSMogball Type type = op.getType();
568cb785caSMogball Value a = op.getLhs();
578cb785caSMogball Value b = op.getRhs();
588cb785caSMogball Value plusOne = createConst(loc, type, 1, rewriter);
598cb785caSMogball Value zero = createConst(loc, type, 0, rewriter);
608cb785caSMogball Value minusOne = createConst(loc, type, -1, rewriter);
61a54f4eaeSMogball // Compute x = (b>0) ? -1 : 1.
62a54f4eaeSMogball Value compare =
63a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
64dec8af70SRiver Riddle Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
65a54f4eaeSMogball // Compute positive res: 1 + ((x+a)/b).
66a54f4eaeSMogball Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
67a54f4eaeSMogball Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
68a54f4eaeSMogball Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
69a54f4eaeSMogball // Compute negative res: - ((-a)/b).
70a54f4eaeSMogball Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
71a54f4eaeSMogball Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
72a54f4eaeSMogball Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
73a54f4eaeSMogball // Result is (a*b>0) ? pos result : neg result.
74a54f4eaeSMogball // Note, we want to avoid using a*b because of possible overflow.
75a54f4eaeSMogball // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
76a54f4eaeSMogball // not particuliarly care if a*b<0 is true or false when b is zero
77a54f4eaeSMogball // as this will result in an illegal divide. So `a*b<0` can be reformulated
78a54f4eaeSMogball // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
79a54f4eaeSMogball // We pick the first expression here.
80a54f4eaeSMogball Value aNeg =
81a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
82a54f4eaeSMogball Value aPos =
83a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
84a54f4eaeSMogball Value bNeg =
85a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
86a54f4eaeSMogball Value bPos =
87a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
88a54f4eaeSMogball Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
89a54f4eaeSMogball Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
90a54f4eaeSMogball Value compareRes =
91a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
92a54f4eaeSMogball // Perform substitution and return success.
93dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
94dec8af70SRiver Riddle negRes);
95a54f4eaeSMogball return success();
96a54f4eaeSMogball }
97a54f4eaeSMogball };
98a54f4eaeSMogball
99a54f4eaeSMogball /// Expands FloorDivSIOp (n, m) into
100a54f4eaeSMogball /// 1) x = (m<0) ? 1 : -1
101a54f4eaeSMogball /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
102a54f4eaeSMogball struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
103a54f4eaeSMogball using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::FloorDivSIOpConverter104a54f4eaeSMogball LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
105a54f4eaeSMogball PatternRewriter &rewriter) const final {
106a54f4eaeSMogball Location loc = op.getLoc();
1078cb785caSMogball Type type = op.getType();
1088cb785caSMogball Value a = op.getLhs();
1098cb785caSMogball Value b = op.getRhs();
1108cb785caSMogball Value plusOne = createConst(loc, type, 1, rewriter);
1118cb785caSMogball Value zero = createConst(loc, type, 0, rewriter);
1128cb785caSMogball Value minusOne = createConst(loc, type, -1, rewriter);
113a54f4eaeSMogball // Compute x = (b<0) ? 1 : -1.
114a54f4eaeSMogball Value compare =
115a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
116dec8af70SRiver Riddle Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
117a54f4eaeSMogball // Compute negative res: -1 - ((x-a)/b).
118a54f4eaeSMogball Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
119a54f4eaeSMogball Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
120a54f4eaeSMogball Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
121a54f4eaeSMogball // Compute positive res: a/b.
122a54f4eaeSMogball Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
123a54f4eaeSMogball // Result is (a*b<0) ? negative result : positive result.
124a54f4eaeSMogball // Note, we want to avoid using a*b because of possible overflow.
125a54f4eaeSMogball // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
126a54f4eaeSMogball // not particuliarly care if a*b<0 is true or false when b is zero
127a54f4eaeSMogball // as this will result in an illegal divide. So `a*b<0` can be reformulated
128a54f4eaeSMogball // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
129a54f4eaeSMogball // We pick the first expression here.
130a54f4eaeSMogball Value aNeg =
131a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
132a54f4eaeSMogball Value aPos =
133a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
134a54f4eaeSMogball Value bNeg =
135a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
136a54f4eaeSMogball Value bPos =
137a54f4eaeSMogball rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
138a54f4eaeSMogball Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
139a54f4eaeSMogball Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
140a54f4eaeSMogball Value compareRes =
141a54f4eaeSMogball rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
142a54f4eaeSMogball // Perform substitution and return success.
143dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
144dec8af70SRiver Riddle posRes);
145a54f4eaeSMogball return success();
146a54f4eaeSMogball }
147a54f4eaeSMogball };
148a54f4eaeSMogball
1499b1d90e8SAlexander Belyaev template <typename OpTy, arith::CmpFPredicate pred>
1509b1d90e8SAlexander Belyaev struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
1519b1d90e8SAlexander Belyaev public:
1529b1d90e8SAlexander Belyaev using OpRewritePattern<OpTy>::OpRewritePattern;
1539b1d90e8SAlexander Belyaev
matchAndRewrite__anond9e264a70111::MaxMinFOpConverter1549b1d90e8SAlexander Belyaev LogicalResult matchAndRewrite(OpTy op,
1559b1d90e8SAlexander Belyaev PatternRewriter &rewriter) const final {
1569b1d90e8SAlexander Belyaev Value lhs = op.getLhs();
1579b1d90e8SAlexander Belyaev Value rhs = op.getRhs();
1589b1d90e8SAlexander Belyaev
1599b1d90e8SAlexander Belyaev Location loc = op.getLoc();
160be1aeb81SChristian Sigg // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
161be1aeb81SChristian Sigg static_assert(pred == arith::CmpFPredicate::UGT ||
16244bb5cd8SKazu Hirata pred == arith::CmpFPredicate::ULT,
16344bb5cd8SKazu Hirata "pred must be either UGT or ULT");
1649b1d90e8SAlexander Belyaev Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
165dec8af70SRiver Riddle Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
1669b1d90e8SAlexander Belyaev
167be1aeb81SChristian Sigg // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
1689b1d90e8SAlexander Belyaev Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
169be1aeb81SChristian Sigg rhs, rhs);
170dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
1719b1d90e8SAlexander Belyaev return success();
1729b1d90e8SAlexander Belyaev }
1739b1d90e8SAlexander Belyaev };
1749b1d90e8SAlexander Belyaev
1759b1d90e8SAlexander Belyaev template <typename OpTy, arith::CmpIPredicate pred>
1769b1d90e8SAlexander Belyaev struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
1779b1d90e8SAlexander Belyaev public:
1789b1d90e8SAlexander Belyaev using OpRewritePattern<OpTy>::OpRewritePattern;
matchAndRewrite__anond9e264a70111::MaxMinIOpConverter1799b1d90e8SAlexander Belyaev LogicalResult matchAndRewrite(OpTy op,
1809b1d90e8SAlexander Belyaev PatternRewriter &rewriter) const final {
1819b1d90e8SAlexander Belyaev Value lhs = op.getLhs();
1829b1d90e8SAlexander Belyaev Value rhs = op.getRhs();
1839b1d90e8SAlexander Belyaev
1849b1d90e8SAlexander Belyaev Location loc = op.getLoc();
1859b1d90e8SAlexander Belyaev Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
186dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
1879b1d90e8SAlexander Belyaev return success();
1889b1d90e8SAlexander Belyaev }
1899b1d90e8SAlexander Belyaev };
1909b1d90e8SAlexander Belyaev
191a54f4eaeSMogball struct ArithmeticExpandOpsPass
192a54f4eaeSMogball : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
runOnOperation__anond9e264a70111::ArithmeticExpandOpsPass19341574554SRiver Riddle void runOnOperation() override {
194a54f4eaeSMogball RewritePatternSet patterns(&getContext());
195a54f4eaeSMogball ConversionTarget target(getContext());
196a54f4eaeSMogball
197a54f4eaeSMogball arith::populateArithmeticExpandOpsPatterns(patterns);
198a54f4eaeSMogball
1991f971e23SRiver Riddle target.addLegalDialect<arith::ArithmeticDialect>();
2009b1d90e8SAlexander Belyaev // clang-format off
2019b1d90e8SAlexander Belyaev target.addIllegalOp<
2029b1d90e8SAlexander Belyaev arith::CeilDivSIOp,
2039b1d90e8SAlexander Belyaev arith::CeilDivUIOp,
2049b1d90e8SAlexander Belyaev arith::FloorDivSIOp,
2059b1d90e8SAlexander Belyaev arith::MaxFOp,
2069b1d90e8SAlexander Belyaev arith::MaxSIOp,
2079b1d90e8SAlexander Belyaev arith::MaxUIOp,
2089b1d90e8SAlexander Belyaev arith::MinFOp,
2099b1d90e8SAlexander Belyaev arith::MinSIOp,
210*894641e9SChristopher Bate arith::MinUIOp
2119b1d90e8SAlexander Belyaev >();
2129b1d90e8SAlexander Belyaev // clang-format on
21341574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target,
21441574554SRiver Riddle std::move(patterns))))
215a54f4eaeSMogball signalPassFailure();
216a54f4eaeSMogball }
217a54f4eaeSMogball };
218a54f4eaeSMogball
219be0a7e9fSMehdi Amini } // namespace
220a54f4eaeSMogball
populateArithmeticExpandOpsPatterns(RewritePatternSet & patterns)221a54f4eaeSMogball void mlir::arith::populateArithmeticExpandOpsPatterns(
222a54f4eaeSMogball RewritePatternSet &patterns) {
2239b1d90e8SAlexander Belyaev // clang-format off
2249b1d90e8SAlexander Belyaev patterns.add<
2259b1d90e8SAlexander Belyaev CeilDivSIOpConverter,
2269b1d90e8SAlexander Belyaev CeilDivUIOpConverter,
2279b1d90e8SAlexander Belyaev FloorDivSIOpConverter,
228be1aeb81SChristian Sigg MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
229be1aeb81SChristian Sigg MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
2309b1d90e8SAlexander Belyaev MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
2319b1d90e8SAlexander Belyaev MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
2329b1d90e8SAlexander Belyaev MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
233*894641e9SChristopher Bate MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
2349b1d90e8SAlexander Belyaev >(patterns.getContext());
2359b1d90e8SAlexander Belyaev // clang-format on
236a54f4eaeSMogball }
237a54f4eaeSMogball
createArithmeticExpandOpsPass()238a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
239a54f4eaeSMogball return std::make_unique<ArithmeticExpandOpsPass>();
240a54f4eaeSMogball }
241