1b0b00432SKrzysztof Drewniak //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2b0b00432SKrzysztof Drewniak // unsigned
3b0b00432SKrzysztof Drewniak // ones when all their arguments and results are statically non-negative --===//
4b0b00432SKrzysztof Drewniak //
5b0b00432SKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6b0b00432SKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
7b0b00432SKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8b0b00432SKrzysztof Drewniak //
9b0b00432SKrzysztof Drewniak //===----------------------------------------------------------------------===//
10b0b00432SKrzysztof Drewniak 
11b0b00432SKrzysztof Drewniak #include "PassDetail.h"
12*ab701975SMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
13*ab701975SMogball #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
14b0b00432SKrzysztof Drewniak #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15b0b00432SKrzysztof Drewniak #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
16e49ae628SKrzysztof Drewniak #include "mlir/Transforms/DialectConversion.h"
17b0b00432SKrzysztof Drewniak 
18b0b00432SKrzysztof Drewniak using namespace mlir;
19b0b00432SKrzysztof Drewniak using namespace mlir::arith;
20*ab701975SMogball using namespace mlir::dataflow;
21b0b00432SKrzysztof Drewniak 
22e49ae628SKrzysztof Drewniak /// Succeeds when a value is statically non-negative in that it has a lower
23b0b00432SKrzysztof Drewniak /// bound on its value (if it is treated as signed) and that bound is
24b0b00432SKrzysztof Drewniak /// non-negative.
staticallyNonNegative(DataFlowSolver & solver,Value v)25*ab701975SMogball static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
26*ab701975SMogball   auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
27*ab701975SMogball   if (!result)
28e49ae628SKrzysztof Drewniak     return failure();
29*ab701975SMogball   const ConstantIntRanges &range = result->getValue().getValue();
30e49ae628SKrzysztof Drewniak   return success(range.smin().isNonNegative());
31b0b00432SKrzysztof Drewniak }
32b0b00432SKrzysztof Drewniak 
33e49ae628SKrzysztof Drewniak /// Succeeds if an op can be converted to its unsigned equivalent without
34e49ae628SKrzysztof Drewniak /// changing its semantics. This is the case when none of its openands or
35e49ae628SKrzysztof Drewniak /// results can be below 0 when analyzed from a signed perspective.
staticallyNonNegative(DataFlowSolver & solver,Operation * op)36*ab701975SMogball static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
37e49ae628SKrzysztof Drewniak                                            Operation *op) {
38*ab701975SMogball   auto nonNegativePred = [&solver](Value v) -> bool {
39*ab701975SMogball     return succeeded(staticallyNonNegative(solver, v));
40b0b00432SKrzysztof Drewniak   };
41e49ae628SKrzysztof Drewniak   return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
42e49ae628SKrzysztof Drewniak                  llvm::all_of(op->getResults(), nonNegativePred));
43b0b00432SKrzysztof Drewniak }
44b0b00432SKrzysztof Drewniak 
45e49ae628SKrzysztof Drewniak /// Succeeds when the comparison predicate is a signed operation and all the
46e49ae628SKrzysztof Drewniak /// operands are non-negative, indicating that the cmpi operation `op` can have
47e49ae628SKrzysztof Drewniak /// its predicate changed to an unsigned equivalent.
isCmpIConvertable(DataFlowSolver & solver,CmpIOp op)48*ab701975SMogball static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
49e49ae628SKrzysztof Drewniak   CmpIPredicate pred = op.getPredicate();
50e49ae628SKrzysztof Drewniak   switch (pred) {
51e49ae628SKrzysztof Drewniak   case CmpIPredicate::sle:
52e49ae628SKrzysztof Drewniak   case CmpIPredicate::slt:
53e49ae628SKrzysztof Drewniak   case CmpIPredicate::sge:
54e49ae628SKrzysztof Drewniak   case CmpIPredicate::sgt:
55*ab701975SMogball     return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
56*ab701975SMogball       return succeeded(staticallyNonNegative(solver, v));
57e49ae628SKrzysztof Drewniak     }));
58e49ae628SKrzysztof Drewniak   default:
59e49ae628SKrzysztof Drewniak     return failure();
60e49ae628SKrzysztof Drewniak   }
61e49ae628SKrzysztof Drewniak }
62e49ae628SKrzysztof Drewniak 
63e49ae628SKrzysztof Drewniak /// Return the unsigned equivalent of a signed comparison predicate,
64e49ae628SKrzysztof Drewniak /// or the predicate itself if there is none.
toUnsignedPred(CmpIPredicate pred)65b0b00432SKrzysztof Drewniak static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
66b0b00432SKrzysztof Drewniak   switch (pred) {
67b0b00432SKrzysztof Drewniak   case CmpIPredicate::sle:
68b0b00432SKrzysztof Drewniak     return CmpIPredicate::ule;
69b0b00432SKrzysztof Drewniak   case CmpIPredicate::slt:
70b0b00432SKrzysztof Drewniak     return CmpIPredicate::ult;
71b0b00432SKrzysztof Drewniak   case CmpIPredicate::sge:
72b0b00432SKrzysztof Drewniak     return CmpIPredicate::uge;
73b0b00432SKrzysztof Drewniak   case CmpIPredicate::sgt:
74b0b00432SKrzysztof Drewniak     return CmpIPredicate::ugt;
75b0b00432SKrzysztof Drewniak   default:
76b0b00432SKrzysztof Drewniak     return pred;
77b0b00432SKrzysztof Drewniak   }
78b0b00432SKrzysztof Drewniak }
79b0b00432SKrzysztof Drewniak 
80b0b00432SKrzysztof Drewniak namespace {
81e49ae628SKrzysztof Drewniak template <typename Signed, typename Unsigned>
82e49ae628SKrzysztof Drewniak struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
83e49ae628SKrzysztof Drewniak   using OpConversionPattern<Signed>::OpConversionPattern;
84e49ae628SKrzysztof Drewniak 
matchAndRewrite__anon8eba27220311::ConvertOpToUnsigned85e49ae628SKrzysztof Drewniak   LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
86e49ae628SKrzysztof Drewniak                                 ConversionPatternRewriter &rw) const override {
87e49ae628SKrzysztof Drewniak     rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
88e49ae628SKrzysztof Drewniak                                     adaptor.getOperands(), op->getAttrs());
89e49ae628SKrzysztof Drewniak     return success();
90e49ae628SKrzysztof Drewniak   }
91e49ae628SKrzysztof Drewniak };
92e49ae628SKrzysztof Drewniak 
93e49ae628SKrzysztof Drewniak struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
94e49ae628SKrzysztof Drewniak   using OpConversionPattern<CmpIOp>::OpConversionPattern;
95e49ae628SKrzysztof Drewniak 
matchAndRewrite__anon8eba27220311::ConvertCmpIToUnsigned96e49ae628SKrzysztof Drewniak   LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
97e49ae628SKrzysztof Drewniak                                 ConversionPatternRewriter &rw) const override {
98e49ae628SKrzysztof Drewniak     rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
99e49ae628SKrzysztof Drewniak                                   op.getLhs(), op.getRhs());
100e49ae628SKrzysztof Drewniak     return success();
101e49ae628SKrzysztof Drewniak   }
102e49ae628SKrzysztof Drewniak };
103e49ae628SKrzysztof Drewniak 
104b0b00432SKrzysztof Drewniak struct ArithmeticUnsignedWhenEquivalentPass
105b0b00432SKrzysztof Drewniak     : public ArithmeticUnsignedWhenEquivalentBase<
106b0b00432SKrzysztof Drewniak           ArithmeticUnsignedWhenEquivalentPass> {
107b0b00432SKrzysztof Drewniak   /// Implementation structure: first find all equivalent ops and collect them,
108b0b00432SKrzysztof Drewniak   /// then perform all the rewrites in a second pass over the target op. This
109b0b00432SKrzysztof Drewniak   /// ensures that analysis results are not invalidated during rewriting.
runOnOperation__anon8eba27220311::ArithmeticUnsignedWhenEquivalentPass110b0b00432SKrzysztof Drewniak   void runOnOperation() override {
111b0b00432SKrzysztof Drewniak     Operation *op = getOperation();
112e49ae628SKrzysztof Drewniak     MLIRContext *ctx = op->getContext();
113*ab701975SMogball     DataFlowSolver solver;
114*ab701975SMogball     solver.load<DeadCodeAnalysis>();
115*ab701975SMogball     solver.load<IntegerRangeAnalysis>();
116*ab701975SMogball     if (failed(solver.initializeAndRun(op)))
117*ab701975SMogball       return signalPassFailure();
118e49ae628SKrzysztof Drewniak 
119e49ae628SKrzysztof Drewniak     ConversionTarget target(*ctx);
120e49ae628SKrzysztof Drewniak     target.addLegalDialect<ArithmeticDialect>();
121e49ae628SKrzysztof Drewniak     target
122e49ae628SKrzysztof Drewniak         .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
123e49ae628SKrzysztof Drewniak                                RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
124*ab701975SMogball             [&solver](Operation *op) -> Optional<bool> {
125*ab701975SMogball               return failed(staticallyNonNegative(solver, op));
126e49ae628SKrzysztof Drewniak             });
127e49ae628SKrzysztof Drewniak     target.addDynamicallyLegalOp<CmpIOp>(
128*ab701975SMogball         [&solver](CmpIOp op) -> Optional<bool> {
129*ab701975SMogball           return failed(isCmpIConvertable(solver, op));
130e49ae628SKrzysztof Drewniak         });
131e49ae628SKrzysztof Drewniak 
132e49ae628SKrzysztof Drewniak     RewritePatternSet patterns(ctx);
133e49ae628SKrzysztof Drewniak     patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
134e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
135e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
136e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
137e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
138e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
139e49ae628SKrzysztof Drewniak                  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
140e49ae628SKrzysztof Drewniak         ctx);
141e49ae628SKrzysztof Drewniak 
142e49ae628SKrzysztof Drewniak     if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
143e49ae628SKrzysztof Drewniak       signalPassFailure();
144e49ae628SKrzysztof Drewniak     }
145b0b00432SKrzysztof Drewniak   }
146b0b00432SKrzysztof Drewniak };
147b0b00432SKrzysztof Drewniak } // end anonymous namespace
148b0b00432SKrzysztof Drewniak 
149b0b00432SKrzysztof Drewniak std::unique_ptr<Pass>
createArithmeticUnsignedWhenEquivalentPass()150b0b00432SKrzysztof Drewniak mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
151b0b00432SKrzysztof Drewniak   return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
152b0b00432SKrzysztof Drewniak }
153