//===- UnsignedWhenEquivalent.cpp - Signed ops to unsigned ones ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
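//
// This pass replaces signed arithmetic and comparison operations with their
// unsigned equivalents when all of their operands and results can be shown,
// via integer range analysis, to be statically non-negative. For example
// (illustrative IR; the constants make the ranges obvious):
//
//   %c4 = arith.constant 4 : i32
//   %c7 = arith.constant 7 : i32
//   %0 = arith.divsi %c7, %c4 : i32
//
// becomes
//
//   %0 = arith.divui %c7, %c4 : i32
//
// since signed and unsigned division agree whenever both inputs are
// non-negative.
//
//===----------------------------------------------------------------------===//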

#include "PassDetail.h"
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::arith;

using OpList = llvm::SmallVector<Operation *>;

/// Returns true when a value is statically known to be non-negative, that is,
/// when the analysis has a result for it and the lower bound of its signed
/// range is non-negative.
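/// For example, the result of `arith.constant 4 : i32` has the singleton
/// range [4, 4] and is statically non-negative, while an unconstrained i32
/// block argument covers the full signed range and is not.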
static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
  Optional<ConstantIntRanges> result = analysis.getResult(v);
  if (!result)
    return false;
  return result->smin().isNonNegative();
}

/// Identify all operations under `root` that have unsigned equivalents and
/// whose operands and results are all statically non-negative.
template <typename... Ts>
static void getConvertibleOps(Operation *root, OpList &toRewrite,
                              IntRangeAnalysis &analysis) {
  auto nonNegativePred = [&analysis](Value v) -> bool {
    return staticallyNonNegative(analysis, v);
  };
  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
    if (isa<Ts...>(orig) &&
        llvm::all_of(orig->getOperands(), nonNegativePred) &&
        llvm::all_of(orig->getResults(), nonNegativePred)) {
      toRewrite.push_back(orig);
    }
  });
}
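
// For instance, the illustrative call
//   getConvertibleOps<DivSIOp, RemSIOp>(root, toRewrite, analysis);
// collects every arith.divsi and arith.remsi under `root` whose operands and
// results are provably non-negative.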

/// Map a signed comparison predicate to its unsigned equivalent; predicates
/// with no distinct unsigned form (eq, ne, and the already-unsigned ones) are
/// returned unchanged.
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
  switch (pred) {
  case CmpIPredicate::sle:
    return CmpIPredicate::ule;
  case CmpIPredicate::slt:
    return CmpIPredicate::ult;
  case CmpIPredicate::sge:
    return CmpIPredicate::uge;
  case CmpIPredicate::sgt:
    return CmpIPredicate::ugt;
  default:
    return pred;
  }
}

/// Find all cmpi ops whose predicates can be replaced by their unsigned
/// equivalents.
static void getConvertibleCmpi(Operation *root, OpList &toRewrite,
                               IntRangeAnalysis &analysis) {
  auto nonNegativePred = [&analysis](Value v) -> bool {
    return staticallyNonNegative(analysis, v);
  };
  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
    CmpIPredicate pred = orig.getPredicate();
    if (toUnsignedPred(pred) != pred &&
        // An i1 result will spuriously and trivially show up as potentially
        // negative (the set bit of a 1-bit signed integer reads as -1), so
        // don't check the results.
        llvm::all_of(orig->getOperands(), nonNegativePred)) {
      toRewrite.push_back(orig.getOperation());
    }
  });
}

/// Return ops to be replaced in the order they should be rewritten.
static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
  OpList ret;
  getConvertibleOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
                    MaxSIOp, ExtSIOp>(root, ret, analysis);
  // Since the cmpi rewrites happen in place, they do not need to be collected
  // in topological order, unlike the ops above.
  getConvertibleCmpi(root, ret, analysis);
  return ret;
}

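/// If `op` is an instance of the signed op T, replace it with an equivalent
/// instance of the unsigned op U, reusing the original location, operands,
/// result types, and attributes. For example, rewriteOp<DivSIOp, DivUIOp>
/// turns an arith.divsi into an arith.divui.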
template <typename T, typename U>
static void rewriteOp(Operation *op, OpBuilder &b) {
  if (isa<T>(op)) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);
    Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
                                   op->getOperands(), op->getAttrs());
    op->replaceAllUsesWith(newOp->getResults());
    op->erase();
  }
}

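/// If `op` is a cmpi, swap its signed predicate for the unsigned equivalent
/// in place; for example, `arith.cmpi slt` becomes `arith.cmpi ult`.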
static void rewriteCmpI(Operation *op, OpBuilder &b) {
  if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
    cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
  }
}

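/// Apply the matching rewrite to each collected op; every op in `toReplace`
/// satisfies exactly one of the rewriters below.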
static void rewrite(Operation *root, const OpList &toReplace) {
  OpBuilder b(root->getContext());
  b.setInsertionPoint(root);
  for (Operation *op : toReplace) {
    rewriteOp<DivSIOp, DivUIOp>(op, b);
    rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
    rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
    rewriteOp<RemSIOp, RemUIOp>(op, b);
    rewriteOp<MinSIOp, MinUIOp>(op, b);
    rewriteOp<MaxSIOp, MaxUIOp>(op, b);
    rewriteOp<ExtSIOp, ExtUIOp>(op, b);
    rewriteCmpI(op, b);
  }
}

namespace {
struct ArithmeticUnsignedWhenEquivalentPass
    : public ArithmeticUnsignedWhenEquivalentBase<
          ArithmeticUnsignedWhenEquivalentPass> {
  /// Implementation structure: first collect all the convertible ops, then
  /// perform all the rewrites in a second pass over the target op. This
  /// ensures that the analysis results are not invalidated by the rewrites
  /// themselves.
  void runOnOperation() override {
    Operation *op = getOperation();
    IntRangeAnalysis analysis(op);
    rewrite(op, getMatching(op, analysis));
  }
};
} // end anonymous namespace

std::unique_ptr<Pass>
mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
  return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
}
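
// Example usage (a sketch; assumes an existing mlir::PassManager `pm`
// anchored at an op this pass can run on):
//   pm.addPass(arith::createArithmeticUnsignedWhenEquivalentPass());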