//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
// unsigned ones when all their arguments and results are statically
// non-negative ------------------------------------------------------------===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "PassDetail.h"
12 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
13 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 using namespace mlir;
19 using namespace mlir::arith;
20 using namespace mlir::dataflow;
21 
22 /// Succeeds when a value is statically non-negative in that it has a lower
23 /// bound on its value (if it is treated as signed) and that bound is
24 /// non-negative.
25 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
26   auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
27   if (!result)
28     return failure();
29   const ConstantIntRanges &range = result->getValue().getValue();
30   return success(range.smin().isNonNegative());
31 }
32 
33 /// Succeeds if an op can be converted to its unsigned equivalent without
34 /// changing its semantics. This is the case when none of its openands or
35 /// results can be below 0 when analyzed from a signed perspective.
36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
37                                            Operation *op) {
38   auto nonNegativePred = [&solver](Value v) -> bool {
39     return succeeded(staticallyNonNegative(solver, v));
40   };
41   return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
42                  llvm::all_of(op->getResults(), nonNegativePred));
43 }
44 
45 /// Succeeds when the comparison predicate is a signed operation and all the
46 /// operands are non-negative, indicating that the cmpi operation `op` can have
47 /// its predicate changed to an unsigned equivalent.
48 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
49   CmpIPredicate pred = op.getPredicate();
50   switch (pred) {
51   case CmpIPredicate::sle:
52   case CmpIPredicate::slt:
53   case CmpIPredicate::sge:
54   case CmpIPredicate::sgt:
55     return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
56       return succeeded(staticallyNonNegative(solver, v));
57     }));
58   default:
59     return failure();
60   }
61 }
62 
63 /// Return the unsigned equivalent of a signed comparison predicate,
64 /// or the predicate itself if there is none.
65 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
66   switch (pred) {
67   case CmpIPredicate::sle:
68     return CmpIPredicate::ule;
69   case CmpIPredicate::slt:
70     return CmpIPredicate::ult;
71   case CmpIPredicate::sge:
72     return CmpIPredicate::uge;
73   case CmpIPredicate::sgt:
74     return CmpIPredicate::ugt;
75   default:
76     return pred;
77   }
78 }
79 
80 namespace {
81 template <typename Signed, typename Unsigned>
82 struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
83   using OpConversionPattern<Signed>::OpConversionPattern;
84 
85   LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
86                                 ConversionPatternRewriter &rw) const override {
87     rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
88                                     adaptor.getOperands(), op->getAttrs());
89     return success();
90   }
91 };
92 
93 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
94   using OpConversionPattern<CmpIOp>::OpConversionPattern;
95 
96   LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
97                                 ConversionPatternRewriter &rw) const override {
98     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
99                                   op.getLhs(), op.getRhs());
100     return success();
101   }
102 };
103 
104 struct ArithmeticUnsignedWhenEquivalentPass
105     : public ArithmeticUnsignedWhenEquivalentBase<
106           ArithmeticUnsignedWhenEquivalentPass> {
107   /// Implementation structure: first find all equivalent ops and collect them,
108   /// then perform all the rewrites in a second pass over the target op. This
109   /// ensures that analysis results are not invalidated during rewriting.
110   void runOnOperation() override {
111     Operation *op = getOperation();
112     MLIRContext *ctx = op->getContext();
113     DataFlowSolver solver;
114     solver.load<DeadCodeAnalysis>();
115     solver.load<IntegerRangeAnalysis>();
116     if (failed(solver.initializeAndRun(op)))
117       return signalPassFailure();
118 
119     ConversionTarget target(*ctx);
120     target.addLegalDialect<ArithmeticDialect>();
121     target
122         .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
123                                RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
124             [&solver](Operation *op) -> Optional<bool> {
125               return failed(staticallyNonNegative(solver, op));
126             });
127     target.addDynamicallyLegalOp<CmpIOp>(
128         [&solver](CmpIOp op) -> Optional<bool> {
129           return failed(isCmpIConvertable(solver, op));
130         });
131 
132     RewritePatternSet patterns(ctx);
133     patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
134                  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
135                  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
136                  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
137                  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
138                  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
139                  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
140         ctx);
141 
142     if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
143       signalPassFailure();
144     }
145   }
146 };
147 } // end anonymous namespace
148 
149 std::unique_ptr<Pass>
150 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
151   return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
152 }
153