1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2 // unsigned
3 // ones when all their arguments and results are statically non-negative --===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "PassDetail.h"
12 #include "mlir/Analysis/IntRangeAnalysis.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 
20 /// Succeeds when a value is statically non-negative in that it has a lower
21 /// bound on its value (if it is treated as signed) and that bound is
22 /// non-negative.
23 static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
24                                            Value v) {
25   Optional<ConstantIntRanges> result = analysis.getResult(v);
26   if (!result.has_value())
27     return failure();
28   const ConstantIntRanges &range = result.value();
29   return success(range.smin().isNonNegative());
30 }
31 
32 /// Succeeds if an op can be converted to its unsigned equivalent without
33 /// changing its semantics. This is the case when none of its openands or
34 /// results can be below 0 when analyzed from a signed perspective.
35 static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
36                                            Operation *op) {
37   auto nonNegativePred = [&analysis](Value v) -> bool {
38     return succeeded(staticallyNonNegative(analysis, v));
39   };
40   return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
41                  llvm::all_of(op->getResults(), nonNegativePred));
42 }
43 
44 /// Succeeds when the comparison predicate is a signed operation and all the
45 /// operands are non-negative, indicating that the cmpi operation `op` can have
46 /// its predicate changed to an unsigned equivalent.
47 static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) {
48   CmpIPredicate pred = op.getPredicate();
49   switch (pred) {
50   case CmpIPredicate::sle:
51   case CmpIPredicate::slt:
52   case CmpIPredicate::sge:
53   case CmpIPredicate::sgt:
54     return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool {
55       return succeeded(staticallyNonNegative(analysis, v));
56     }));
57   default:
58     return failure();
59   }
60 }
61 
62 /// Return the unsigned equivalent of a signed comparison predicate,
63 /// or the predicate itself if there is none.
64 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
65   switch (pred) {
66   case CmpIPredicate::sle:
67     return CmpIPredicate::ule;
68   case CmpIPredicate::slt:
69     return CmpIPredicate::ult;
70   case CmpIPredicate::sge:
71     return CmpIPredicate::uge;
72   case CmpIPredicate::sgt:
73     return CmpIPredicate::ugt;
74   default:
75     return pred;
76   }
77 }
78 
79 namespace {
80 template <typename Signed, typename Unsigned>
81 struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
82   using OpConversionPattern<Signed>::OpConversionPattern;
83 
84   LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
85                                 ConversionPatternRewriter &rw) const override {
86     rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
87                                     adaptor.getOperands(), op->getAttrs());
88     return success();
89   }
90 };
91 
92 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
93   using OpConversionPattern<CmpIOp>::OpConversionPattern;
94 
95   LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
96                                 ConversionPatternRewriter &rw) const override {
97     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
98                                   op.getLhs(), op.getRhs());
99     return success();
100   }
101 };
102 
103 struct ArithmeticUnsignedWhenEquivalentPass
104     : public ArithmeticUnsignedWhenEquivalentBase<
105           ArithmeticUnsignedWhenEquivalentPass> {
106   /// Implementation structure: first find all equivalent ops and collect them,
107   /// then perform all the rewrites in a second pass over the target op. This
108   /// ensures that analysis results are not invalidated during rewriting.
109   void runOnOperation() override {
110     Operation *op = getOperation();
111     MLIRContext *ctx = op->getContext();
112     IntRangeAnalysis analysis(op);
113 
114     ConversionTarget target(*ctx);
115     target.addLegalDialect<ArithmeticDialect>();
116     target
117         .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
118                                RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
119             [&analysis](Operation *op) -> Optional<bool> {
120               return failed(staticallyNonNegative(analysis, op));
121             });
122     target.addDynamicallyLegalOp<CmpIOp>(
123         [&analysis](CmpIOp op) -> Optional<bool> {
124           return failed(isCmpIConvertable(analysis, op));
125         });
126 
127     RewritePatternSet patterns(ctx);
128     patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
129                  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
130                  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
131                  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
132                  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
133                  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
134                  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
135         ctx);
136 
137     if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
138       signalPassFailure();
139     }
140   }
141 };
142 } // end anonymous namespace
143 
144 std::unique_ptr<Pass>
145 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
146   return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
147 }
148