//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
// unsigned ones when all their arguments and results are statically
// non-negative --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::arith;

using OpList = llvm::SmallVector<Operation *>;

/// Returns true when a value is statically non-negative, i.e., it has a lower
/// bound on its value (when treated as signed) and that bound is non-negative.
static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
  Optional<ConstantIntRanges> result = analysis.getResult(v);
  if (!result.hasValue())
    return false;
  const ConstantIntRanges &range = result.getValue();
  return range.smin().isNonNegative();
}

/// Identify all operations under `root` that are instances of one of the
/// signed op types `Ts` (each of which has an unsigned equivalent) and whose
/// operands and results are all statically non-negative.
template <typename... Ts>
static void getConvertableOps(Operation *root, OpList &toRewrite,
                              IntRangeAnalysis &analysis) {
  auto nonNegativePred = [&analysis](Value v) -> bool {
    return staticallyNonNegative(analysis, v);
  };
  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
    if (isa<Ts...>(orig) &&
        llvm::all_of(orig->getOperands(), nonNegativePred) &&
        llvm::all_of(orig->getResults(), nonNegativePred)) {
      toRewrite.push_back(orig);
    }
  });
}

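/// Return the unsigned counterpart of a signed comparison predicate; leave
/// predicates that have no signedness distinction (and already-unsigned ones)
/// unchanged.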
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
  switch (pred) {
  case CmpIPredicate::sle:
    return CmpIPredicate::ule;
  case CmpIPredicate::slt:
    return CmpIPredicate::ult;
  case CmpIPredicate::sge:
    return CmpIPredicate::uge;
  case CmpIPredicate::sgt:
    return CmpIPredicate::ugt;
  default:
    return pred;
  }
}

/// Find all cmpi ops that can be replaced by their unsigned equivalents.
static void getConvertableCmpi(Operation *root, OpList &toRewrite,
                               IntRangeAnalysis &analysis) {
  auto nonNegativePred = [&analysis](Value v) -> bool {
    return staticallyNonNegative(analysis, v);
  };
  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
    CmpIPredicate pred = orig.getPredicate();
    if (toUnsignedPred(pred) != pred &&
        // i1 results will spuriously and trivially show up as potentially
        // negative, so don't check the results.
        llvm::all_of(orig->getOperands(), nonNegativePred)) {
      toRewrite.push_back(orig.getOperation());
    }
  });
}

/// Return ops to be replaced in the order they should be rewritten.
static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
  OpList ret;
  getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
                    MaxSIOp, ExtSIOp>(root, ret, analysis);
  // Since cmpi rewrites are in-place changes, they don't need to be in
  // topological order like the others.
  getConvertableCmpi(root, ret, analysis);
  return ret;
}

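/// If `op` is an instance of the signed op `T`, replace it with the unsigned
/// op `U`, forwarding the original location, result types, operands, and
/// attributes. Returns true if a replacement was made.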
template <typename T, typename U>
static bool rewriteOp(Operation *op, OpBuilder &b) {
  if (isa<T>(op)) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
                                   op->getOperands(), op->getAttrs());
    op->replaceAllUsesWith(newOp->getResults());
    op->erase();
    return true;
  }
  return false;
}

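/// If `op` is a cmpi, swap its predicate for the unsigned equivalent in place.
/// Returns true if `op` was a cmpi.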
static bool rewriteCmpI(Operation *op, OpBuilder &b) {
  if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
    cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
    return true;
  }
  return false;
}

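/// Replace each collected op in `toReplace` with its unsigned equivalent.
/// Exactly one of the rewrite helpers matches each collected op; the
/// short-circuiting `||` chain stops at the first helper that fires.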
static void rewrite(Operation *root, const OpList &toReplace) {
  OpBuilder b(root->getContext());
  b.setInsertionPoint(root);
  for (Operation *op : toReplace) {
    rewriteOp<DivSIOp, DivUIOp>(op, b) ||
        rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) ||
        rewriteOp<FloorDivSIOp, DivUIOp>(op, b) ||
        rewriteOp<RemSIOp, RemUIOp>(op, b) ||
        rewriteOp<MinSIOp, MinUIOp>(op, b) ||
        rewriteOp<MaxSIOp, MaxUIOp>(op, b) ||
        rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b);
  }
}

namespace {
struct ArithmeticUnsignedWhenEquivalentPass
    : public ArithmeticUnsignedWhenEquivalentBase<
          ArithmeticUnsignedWhenEquivalentPass> {
  /// Implementation structure: first find all equivalent ops and collect them,
  /// then perform all the rewrites in a second pass over the target op. This
  /// ensures that analysis results are not invalidated during rewriting.
  void runOnOperation() override {
    Operation *op = getOperation();
    IntRangeAnalysis analysis(op);
    rewrite(op, getMatching(op, analysis));
  }
};
} // end anonymous namespace

std::unique_ptr<Pass>
mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
  return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
}