1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with 2 // unsigned 3 // ones when all their arguments and results are statically non-negative --===// 4 // 5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 6 // See https://llvm.org/LICENSE.txt for license information. 7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8 // 9 //===----------------------------------------------------------------------===// 10 11 #include "PassDetail.h" 12 #include "mlir/Analysis/IntRangeAnalysis.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 using namespace mlir::arith; 19 20 /// Succeeds when a value is statically non-negative in that it has a lower 21 /// bound on its value (if it is treated as signed) and that bound is 22 /// non-negative. 23 static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, 24 Value v) { 25 Optional<ConstantIntRanges> result = analysis.getResult(v); 26 if (!result.has_value()) 27 return failure(); 28 const ConstantIntRanges &range = result.value(); 29 return success(range.smin().isNonNegative()); 30 } 31 32 /// Succeeds if an op can be converted to its unsigned equivalent without 33 /// changing its semantics. This is the case when none of its openands or 34 /// results can be below 0 when analyzed from a signed perspective. 35 static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, 36 Operation *op) { 37 auto nonNegativePred = [&analysis](Value v) -> bool { 38 return succeeded(staticallyNonNegative(analysis, v)); 39 }; 40 return success(llvm::all_of(op->getOperands(), nonNegativePred) && 41 llvm::all_of(op->getResults(), nonNegativePred)); 42 } 43 44 /// Succeeds when the comparison predicate is a signed operation and all the 45 /// operands are non-negative, indicating that the cmpi operation `op` can have 46 /// its predicate changed to an unsigned equivalent. 47 static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) { 48 CmpIPredicate pred = op.getPredicate(); 49 switch (pred) { 50 case CmpIPredicate::sle: 51 case CmpIPredicate::slt: 52 case CmpIPredicate::sge: 53 case CmpIPredicate::sgt: 54 return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool { 55 return succeeded(staticallyNonNegative(analysis, v)); 56 })); 57 default: 58 return failure(); 59 } 60 } 61 62 /// Return the unsigned equivalent of a signed comparison predicate, 63 /// or the predicate itself if there is none. 64 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { 65 switch (pred) { 66 case CmpIPredicate::sle: 67 return CmpIPredicate::ule; 68 case CmpIPredicate::slt: 69 return CmpIPredicate::ult; 70 case CmpIPredicate::sge: 71 return CmpIPredicate::uge; 72 case CmpIPredicate::sgt: 73 return CmpIPredicate::ugt; 74 default: 75 return pred; 76 } 77 } 78 79 namespace { 80 template <typename Signed, typename Unsigned> 81 struct ConvertOpToUnsigned : OpConversionPattern<Signed> { 82 using OpConversionPattern<Signed>::OpConversionPattern; 83 84 LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, 85 ConversionPatternRewriter &rw) const override { 86 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), 87 adaptor.getOperands(), op->getAttrs()); 88 return success(); 89 } 90 }; 91 92 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> { 93 using OpConversionPattern<CmpIOp>::OpConversionPattern; 94 95 LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, 96 ConversionPatternRewriter &rw) const override { 97 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()), 98 op.getLhs(), op.getRhs()); 99 return success(); 100 } 101 }; 102 103 struct ArithmeticUnsignedWhenEquivalentPass 104 : public ArithmeticUnsignedWhenEquivalentBase< 105 ArithmeticUnsignedWhenEquivalentPass> { 106 /// Implementation structure: first find all equivalent ops and collect them, 107 /// then perform all the rewrites in a second pass over the target op. This 108 /// ensures that analysis results are not invalidated during rewriting. 109 void runOnOperation() override { 110 Operation *op = getOperation(); 111 MLIRContext *ctx = op->getContext(); 112 IntRangeAnalysis analysis(op); 113 114 ConversionTarget target(*ctx); 115 target.addLegalDialect<ArithmeticDialect>(); 116 target 117 .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp, 118 RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>( 119 [&analysis](Operation *op) -> Optional<bool> { 120 return failed(staticallyNonNegative(analysis, op)); 121 }); 122 target.addDynamicallyLegalOp<CmpIOp>( 123 [&analysis](CmpIOp op) -> Optional<bool> { 124 return failed(isCmpIConvertable(analysis, op)); 125 }); 126 127 RewritePatternSet patterns(ctx); 128 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, 129 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, 130 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, 131 ConvertOpToUnsigned<RemSIOp, RemUIOp>, 132 ConvertOpToUnsigned<MinSIOp, MinUIOp>, 133 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, 134 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( 135 ctx); 136 137 if (failed(applyPartialConversion(op, target, std::move(patterns)))) { 138 signalPassFailure(); 139 } 140 } 141 }; 142 } // end anonymous namespace 143 144 std::unique_ptr<Pass> 145 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() { 146 return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>(); 147 } 148