1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with 2 // unsigned 3 // ones when all their arguments and results are statically non-negative --===// 4 // 5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 6 // See https://llvm.org/LICENSE.txt for license information. 7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 8 // 9 //===----------------------------------------------------------------------===// 10 11 #include "PassDetail.h" 12 #include "mlir/Analysis/IntRangeAnalysis.h" 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 15 16 using namespace mlir; 17 using namespace mlir::arith; 18 19 using OpList = llvm::SmallVector<Operation *>; 20 21 /// Returns true when a value is statically non-negative in that it has a lower 22 /// bound on its value (if it is treated as signed) and that bound is 23 /// non-negative. 24 static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) { 25 Optional<ConstantIntRanges> result = analysis.getResult(v); 26 if (!result.hasValue()) 27 return false; 28 const ConstantIntRanges &range = result.getValue(); 29 return (range.smin().isNonNegative()); 30 } 31 32 /// Identify all operations in a block that have signed equivalents and have 33 /// operands and results that are statically non-negative. 34 template <typename... Ts> 35 static void getConvertableOps(Operation *root, OpList &toRewrite, 36 IntRangeAnalysis &analysis) { 37 auto nonNegativePred = [&analysis](Value v) -> bool { 38 return staticallyNonNegative(analysis, v); 39 }; 40 root->walk([&nonNegativePred, &toRewrite](Operation *orig) { 41 if (isa<Ts...>(orig) && 42 llvm::all_of(orig->getOperands(), nonNegativePred) && 43 llvm::all_of(orig->getResults(), nonNegativePred)) { 44 toRewrite.push_back(orig); 45 } 46 }); 47 } 48 49 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { 50 switch (pred) { 51 case CmpIPredicate::sle: 52 return CmpIPredicate::ule; 53 case CmpIPredicate::slt: 54 return CmpIPredicate::ult; 55 case CmpIPredicate::sge: 56 return CmpIPredicate::uge; 57 case CmpIPredicate::sgt: 58 return CmpIPredicate::ugt; 59 default: 60 return pred; 61 } 62 } 63 64 /// Find all cmpi ops that can be replaced by their unsigned equivalents. 65 static void getConvertableCmpi(Operation *root, OpList &toRewrite, 66 IntRangeAnalysis &analysis) { 67 auto nonNegativePred = [&analysis](Value v) -> bool { 68 return staticallyNonNegative(analysis, v); 69 }; 70 root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) { 71 CmpIPredicate pred = orig.getPredicate(); 72 if (toUnsignedPred(pred) != pred && 73 // i1 will spuriously and trivially show up as pontentially negative, 74 // so don't check the results 75 llvm::all_of(orig->getOperands(), nonNegativePred)) { 76 toRewrite.push_back(orig.getOperation()); 77 } 78 }); 79 } 80 81 /// Return ops to be replaced in the order they should be rewritten. 82 static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) { 83 OpList ret; 84 getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp, 85 MaxSIOp, ExtSIOp>(root, ret, analysis); 86 // Since these are in-place changes, they don't need to be topological order 87 // like the others. 88 getConvertableCmpi(root, ret, analysis); 89 return ret; 90 } 91 92 template <typename T, typename U> 93 static bool rewriteOp(Operation *op, OpBuilder &b) { 94 if (isa<T>(op)) { 95 OpBuilder::InsertionGuard guard(b); 96 b.setInsertionPoint(op); 97 Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(), 98 op->getOperands(), op->getAttrs()); 99 op->replaceAllUsesWith(newOp->getResults()); 100 op->erase(); 101 return true; 102 } 103 return false; 104 } 105 106 static bool rewriteCmpI(Operation *op, OpBuilder &b) { 107 if (auto cmpOp = dyn_cast<CmpIOp>(op)) { 108 cmpOp.setPredicateAttr(CmpIPredicateAttr::get( 109 b.getContext(), toUnsignedPred(cmpOp.getPredicate()))); 110 return true; 111 } 112 return false; 113 } 114 115 static void rewrite(Operation *root, const OpList &toReplace) { 116 OpBuilder b(root->getContext()); 117 b.setInsertionPoint(root); 118 for (Operation *op : toReplace) { 119 rewriteOp<DivSIOp, DivUIOp>(op, b) || 120 rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) || 121 rewriteOp<FloorDivSIOp, DivUIOp>(op, b) || 122 rewriteOp<RemSIOp, RemUIOp>(op, b) || 123 rewriteOp<MinSIOp, MinUIOp>(op, b) || 124 rewriteOp<MaxSIOp, MaxUIOp>(op, b) || 125 rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b); 126 } 127 } 128 129 namespace { 130 struct ArithmeticUnsignedWhenEquivalentPass 131 : public ArithmeticUnsignedWhenEquivalentBase< 132 ArithmeticUnsignedWhenEquivalentPass> { 133 /// Implementation structure: first find all equivalent ops and collect them, 134 /// then perform all the rewrites in a second pass over the target op. This 135 /// ensures that analysis results are not invalidated during rewriting. 136 void runOnOperation() override { 137 Operation *op = getOperation(); 138 IntRangeAnalysis analysis(op); 139 rewrite(op, getMatching(op, analysis)); 140 } 141 }; 142 } // end anonymous namespace 143 144 std::unique_ptr<Pass> 145 mlir::arith::createArithmeticUnsignedWhenEquivalentPass() { 146 return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>(); 147 } 148