//===- UnsignedWhenEquivalent.cpp - Signed-to-unsigned op rewriting -------===//
//
// Pass to replace signed operations with their unsigned equivalents when all
// of their arguments and results are statically non-negative.
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10
11 #include "PassDetail.h"
12 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
13 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
16 #include "mlir/Transforms/DialectConversion.h"
17
18 using namespace mlir;
19 using namespace mlir::arith;
20 using namespace mlir::dataflow;
21
22 /// Succeeds when a value is statically non-negative in that it has a lower
23 /// bound on its value (if it is treated as signed) and that bound is
24 /// non-negative.
staticallyNonNegative(DataFlowSolver & solver,Value v)25 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
26 auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
27 if (!result)
28 return failure();
29 const ConstantIntRanges &range = result->getValue().getValue();
30 return success(range.smin().isNonNegative());
31 }
32
33 /// Succeeds if an op can be converted to its unsigned equivalent without
34 /// changing its semantics. This is the case when none of its openands or
35 /// results can be below 0 when analyzed from a signed perspective.
staticallyNonNegative(DataFlowSolver & solver,Operation * op)36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
37 Operation *op) {
38 auto nonNegativePred = [&solver](Value v) -> bool {
39 return succeeded(staticallyNonNegative(solver, v));
40 };
41 return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
42 llvm::all_of(op->getResults(), nonNegativePred));
43 }
44
45 /// Succeeds when the comparison predicate is a signed operation and all the
46 /// operands are non-negative, indicating that the cmpi operation `op` can have
47 /// its predicate changed to an unsigned equivalent.
isCmpIConvertable(DataFlowSolver & solver,CmpIOp op)48 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
49 CmpIPredicate pred = op.getPredicate();
50 switch (pred) {
51 case CmpIPredicate::sle:
52 case CmpIPredicate::slt:
53 case CmpIPredicate::sge:
54 case CmpIPredicate::sgt:
55 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
56 return succeeded(staticallyNonNegative(solver, v));
57 }));
58 default:
59 return failure();
60 }
61 }
62
63 /// Return the unsigned equivalent of a signed comparison predicate,
64 /// or the predicate itself if there is none.
toUnsignedPred(CmpIPredicate pred)65 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
66 switch (pred) {
67 case CmpIPredicate::sle:
68 return CmpIPredicate::ule;
69 case CmpIPredicate::slt:
70 return CmpIPredicate::ult;
71 case CmpIPredicate::sge:
72 return CmpIPredicate::uge;
73 case CmpIPredicate::sgt:
74 return CmpIPredicate::ugt;
75 default:
76 return pred;
77 }
78 }
79
80 namespace {
81 template <typename Signed, typename Unsigned>
82 struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
83 using OpConversionPattern<Signed>::OpConversionPattern;
84
matchAndRewrite__anon8eba27220311::ConvertOpToUnsigned85 LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
86 ConversionPatternRewriter &rw) const override {
87 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
88 adaptor.getOperands(), op->getAttrs());
89 return success();
90 }
91 };
92
93 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
94 using OpConversionPattern<CmpIOp>::OpConversionPattern;
95
matchAndRewrite__anon8eba27220311::ConvertCmpIToUnsigned96 LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
97 ConversionPatternRewriter &rw) const override {
98 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
99 op.getLhs(), op.getRhs());
100 return success();
101 }
102 };
103
104 struct ArithmeticUnsignedWhenEquivalentPass
105 : public ArithmeticUnsignedWhenEquivalentBase<
106 ArithmeticUnsignedWhenEquivalentPass> {
107 /// Implementation structure: first find all equivalent ops and collect them,
108 /// then perform all the rewrites in a second pass over the target op. This
109 /// ensures that analysis results are not invalidated during rewriting.
runOnOperation__anon8eba27220311::ArithmeticUnsignedWhenEquivalentPass110 void runOnOperation() override {
111 Operation *op = getOperation();
112 MLIRContext *ctx = op->getContext();
113 DataFlowSolver solver;
114 solver.load<DeadCodeAnalysis>();
115 solver.load<IntegerRangeAnalysis>();
116 if (failed(solver.initializeAndRun(op)))
117 return signalPassFailure();
118
119 ConversionTarget target(*ctx);
120 target.addLegalDialect<ArithmeticDialect>();
121 target
122 .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
123 RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
124 [&solver](Operation *op) -> Optional<bool> {
125 return failed(staticallyNonNegative(solver, op));
126 });
127 target.addDynamicallyLegalOp<CmpIOp>(
128 [&solver](CmpIOp op) -> Optional<bool> {
129 return failed(isCmpIConvertable(solver, op));
130 });
131
132 RewritePatternSet patterns(ctx);
133 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
134 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
135 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
136 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
137 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
138 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
139 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
140 ctx);
141
142 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
143 signalPassFailure();
144 }
145 }
146 };
147 } // end anonymous namespace
148
149 std::unique_ptr<Pass>
createArithmeticUnsignedWhenEquivalentPass()150 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
151 return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
152 }
153