1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with 2 // unsigned 3 // ones when all their arguments and results are statically non-negative --===// 4 // 5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 6 // See https://llvm.org/LICENSE.txt for license information. 7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8 // 9 //===----------------------------------------------------------------------===// 10 11 #include "PassDetail.h" 12 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 13 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" 14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 15 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 18 using namespace mlir; 19 using namespace mlir::arith; 20 using namespace mlir::dataflow; 21 22 /// Succeeds when a value is statically non-negative in that it has a lower 23 /// bound on its value (if it is treated as signed) and that bound is 24 /// non-negative. 25 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { 26 auto *result = solver.lookupState<IntegerValueRangeLattice>(v); 27 if (!result) 28 return failure(); 29 const ConstantIntRanges &range = result->getValue().getValue(); 30 return success(range.smin().isNonNegative()); 31 } 32 33 /// Succeeds if an op can be converted to its unsigned equivalent without 34 /// changing its semantics. This is the case when none of its openands or 35 /// results can be below 0 when analyzed from a signed perspective. 36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, 37 Operation *op) { 38 auto nonNegativePred = [&solver](Value v) -> bool { 39 return succeeded(staticallyNonNegative(solver, v)); 40 }; 41 return success(llvm::all_of(op->getOperands(), nonNegativePred) && 42 llvm::all_of(op->getResults(), nonNegativePred)); 43 } 44 45 /// Succeeds when the comparison predicate is a signed operation and all the 46 /// operands are non-negative, indicating that the cmpi operation `op` can have 47 /// its predicate changed to an unsigned equivalent. 48 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) { 49 CmpIPredicate pred = op.getPredicate(); 50 switch (pred) { 51 case CmpIPredicate::sle: 52 case CmpIPredicate::slt: 53 case CmpIPredicate::sge: 54 case CmpIPredicate::sgt: 55 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool { 56 return succeeded(staticallyNonNegative(solver, v)); 57 })); 58 default: 59 return failure(); 60 } 61 } 62 63 /// Return the unsigned equivalent of a signed comparison predicate, 64 /// or the predicate itself if there is none. 65 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { 66 switch (pred) { 67 case CmpIPredicate::sle: 68 return CmpIPredicate::ule; 69 case CmpIPredicate::slt: 70 return CmpIPredicate::ult; 71 case CmpIPredicate::sge: 72 return CmpIPredicate::uge; 73 case CmpIPredicate::sgt: 74 return CmpIPredicate::ugt; 75 default: 76 return pred; 77 } 78 } 79 80 namespace { 81 template <typename Signed, typename Unsigned> 82 struct ConvertOpToUnsigned : OpConversionPattern<Signed> { 83 using OpConversionPattern<Signed>::OpConversionPattern; 84 85 LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, 86 ConversionPatternRewriter &rw) const override { 87 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), 88 adaptor.getOperands(), op->getAttrs()); 89 return success(); 90 } 91 }; 92 93 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> { 94 using OpConversionPattern<CmpIOp>::OpConversionPattern; 95 96 LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, 97 ConversionPatternRewriter &rw) const override { 98 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()), 99 op.getLhs(), op.getRhs()); 100 return success(); 101 } 102 }; 103 104 struct ArithmeticUnsignedWhenEquivalentPass 105 : public ArithmeticUnsignedWhenEquivalentBase< 106 ArithmeticUnsignedWhenEquivalentPass> { 107 /// Implementation structure: first find all equivalent ops and collect them, 108 /// then perform all the rewrites in a second pass over the target op. This 109 /// ensures that analysis results are not invalidated during rewriting. 110 void runOnOperation() override { 111 Operation *op = getOperation(); 112 MLIRContext *ctx = op->getContext(); 113 DataFlowSolver solver; 114 solver.load<DeadCodeAnalysis>(); 115 solver.load<IntegerRangeAnalysis>(); 116 if (failed(solver.initializeAndRun(op))) 117 return signalPassFailure(); 118 119 ConversionTarget target(*ctx); 120 target.addLegalDialect<ArithmeticDialect>(); 121 target 122 .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp, 123 RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>( 124 [&solver](Operation *op) -> Optional<bool> { 125 return failed(staticallyNonNegative(solver, op)); 126 }); 127 target.addDynamicallyLegalOp<CmpIOp>( 128 [&solver](CmpIOp op) -> Optional<bool> { 129 return failed(isCmpIConvertable(solver, op)); 130 }); 131 132 RewritePatternSet patterns(ctx); 133 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, 134 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, 135 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, 136 ConvertOpToUnsigned<RemSIOp, RemUIOp>, 137 ConvertOpToUnsigned<MinSIOp, MinUIOp>, 138 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, 139 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( 140 ctx); 141 142 if (failed(applyPartialConversion(op, target, std::move(patterns)))) { 143 signalPassFailure(); 144 } 145 } 146 }; 147 } // end anonymous namespace 148 149 std::unique_ptr<Pass> 150 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() { 151 return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>(); 152 } 153