1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
14
15 using namespace mlir;
16
17 /// Create an integer or index constant.
createConst(Location loc,Type type,int value,PatternRewriter & rewriter)18 static Value createConst(Location loc, Type type, int value,
19 PatternRewriter &rewriter) {
20 return rewriter.create<arith::ConstantOp>(
21 loc, rewriter.getIntegerAttr(type, value));
22 }
23
24 namespace {
25
26 /// Expands CeilDivUIOp (n, m) into
27 /// n == 0 ? 0 : ((n-1) / m) + 1
28 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
29 using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivUIOpConverter30 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
31 PatternRewriter &rewriter) const final {
32 Location loc = op.getLoc();
33 Value a = op.getLhs();
34 Value b = op.getRhs();
35 Value zero = createConst(loc, a.getType(), 0, rewriter);
36 Value compare =
37 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
38 Value one = createConst(loc, a.getType(), 1, rewriter);
39 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
40 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
41 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
42 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
43 return success();
44 }
45 };
46
47 /// Expands CeilDivSIOp (n, m) into
48 /// 1) x = (m > 0) ? -1 : 1
49 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
50 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
51 using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::CeilDivSIOpConverter52 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
53 PatternRewriter &rewriter) const final {
54 Location loc = op.getLoc();
55 Type type = op.getType();
56 Value a = op.getLhs();
57 Value b = op.getRhs();
58 Value plusOne = createConst(loc, type, 1, rewriter);
59 Value zero = createConst(loc, type, 0, rewriter);
60 Value minusOne = createConst(loc, type, -1, rewriter);
61 // Compute x = (b>0) ? -1 : 1.
62 Value compare =
63 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
64 Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
65 // Compute positive res: 1 + ((x+a)/b).
66 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
67 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
68 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
69 // Compute negative res: - ((-a)/b).
70 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
71 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
72 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
73 // Result is (a*b>0) ? pos result : neg result.
74 // Note, we want to avoid using a*b because of possible overflow.
75 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
76 // not particuliarly care if a*b<0 is true or false when b is zero
77 // as this will result in an illegal divide. So `a*b<0` can be reformulated
78 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
79 // We pick the first expression here.
80 Value aNeg =
81 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
82 Value aPos =
83 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
84 Value bNeg =
85 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
86 Value bPos =
87 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
88 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
89 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
90 Value compareRes =
91 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
92 // Perform substitution and return success.
93 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
94 negRes);
95 return success();
96 }
97 };
98
99 /// Expands FloorDivSIOp (n, m) into
100 /// 1) x = (m<0) ? 1 : -1
101 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
102 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
103 using OpRewritePattern::OpRewritePattern;
matchAndRewrite__anond9e264a70111::FloorDivSIOpConverter104 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
105 PatternRewriter &rewriter) const final {
106 Location loc = op.getLoc();
107 Type type = op.getType();
108 Value a = op.getLhs();
109 Value b = op.getRhs();
110 Value plusOne = createConst(loc, type, 1, rewriter);
111 Value zero = createConst(loc, type, 0, rewriter);
112 Value minusOne = createConst(loc, type, -1, rewriter);
113 // Compute x = (b<0) ? 1 : -1.
114 Value compare =
115 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
116 Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
117 // Compute negative res: -1 - ((x-a)/b).
118 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
119 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
120 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
121 // Compute positive res: a/b.
122 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
123 // Result is (a*b<0) ? negative result : positive result.
124 // Note, we want to avoid using a*b because of possible overflow.
125 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
126 // not particuliarly care if a*b<0 is true or false when b is zero
127 // as this will result in an illegal divide. So `a*b<0` can be reformulated
128 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
129 // We pick the first expression here.
130 Value aNeg =
131 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
132 Value aPos =
133 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
134 Value bNeg =
135 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
136 Value bPos =
137 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
138 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
139 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
140 Value compareRes =
141 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
142 // Perform substitution and return success.
143 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
144 posRes);
145 return success();
146 }
147 };
148
149 template <typename OpTy, arith::CmpFPredicate pred>
150 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
151 public:
152 using OpRewritePattern<OpTy>::OpRewritePattern;
153
matchAndRewrite__anond9e264a70111::MaxMinFOpConverter154 LogicalResult matchAndRewrite(OpTy op,
155 PatternRewriter &rewriter) const final {
156 Value lhs = op.getLhs();
157 Value rhs = op.getRhs();
158
159 Location loc = op.getLoc();
160 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
161 static_assert(pred == arith::CmpFPredicate::UGT ||
162 pred == arith::CmpFPredicate::ULT,
163 "pred must be either UGT or ULT");
164 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
165 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
166
167 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
168 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
169 rhs, rhs);
170 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
171 return success();
172 }
173 };
174
175 template <typename OpTy, arith::CmpIPredicate pred>
176 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
177 public:
178 using OpRewritePattern<OpTy>::OpRewritePattern;
matchAndRewrite__anond9e264a70111::MaxMinIOpConverter179 LogicalResult matchAndRewrite(OpTy op,
180 PatternRewriter &rewriter) const final {
181 Value lhs = op.getLhs();
182 Value rhs = op.getRhs();
183
184 Location loc = op.getLoc();
185 Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
186 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
187 return success();
188 }
189 };
190
191 struct ArithmeticExpandOpsPass
192 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
runOnOperation__anond9e264a70111::ArithmeticExpandOpsPass193 void runOnOperation() override {
194 RewritePatternSet patterns(&getContext());
195 ConversionTarget target(getContext());
196
197 arith::populateArithmeticExpandOpsPatterns(patterns);
198
199 target.addLegalDialect<arith::ArithmeticDialect>();
200 // clang-format off
201 target.addIllegalOp<
202 arith::CeilDivSIOp,
203 arith::CeilDivUIOp,
204 arith::FloorDivSIOp,
205 arith::MaxFOp,
206 arith::MaxSIOp,
207 arith::MaxUIOp,
208 arith::MinFOp,
209 arith::MinSIOp,
210 arith::MinUIOp
211 >();
212 // clang-format on
213 if (failed(applyPartialConversion(getOperation(), target,
214 std::move(patterns))))
215 signalPassFailure();
216 }
217 };
218
219 } // namespace
220
populateArithmeticExpandOpsPatterns(RewritePatternSet & patterns)221 void mlir::arith::populateArithmeticExpandOpsPatterns(
222 RewritePatternSet &patterns) {
223 // clang-format off
224 patterns.add<
225 CeilDivSIOpConverter,
226 CeilDivUIOpConverter,
227 FloorDivSIOpConverter,
228 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
229 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>,
230 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
231 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
232 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
233 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
234 >(patterns.getContext());
235 // clang-format on
236 }
237
createArithmeticExpandOpsPass()238 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
239 return std::make_unique<ArithmeticExpandOpsPass>();
240 }
241