1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/IR/TypeUtilities.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 16 using namespace mlir; 17 18 /// Create an integer or index constant. 19 static Value createConst(Location loc, Type type, int value, 20 PatternRewriter &rewriter) { 21 return rewriter.create<arith::ConstantOp>( 22 loc, rewriter.getIntegerAttr(type, value)); 23 } 24 25 namespace { 26 27 /// Expands CeilDivUIOp (n, m) into 28 /// n == 0 ? 0 : ((n-1) / m) + 1 29 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { 30 using OpRewritePattern::OpRewritePattern; 31 LogicalResult matchAndRewrite(arith::CeilDivUIOp op, 32 PatternRewriter &rewriter) const final { 33 Location loc = op.getLoc(); 34 Value a = op.getLhs(); 35 Value b = op.getRhs(); 36 Value zero = createConst(loc, a.getType(), 0, rewriter); 37 Value compare = 38 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); 39 Value one = createConst(loc, a.getType(), 1, rewriter); 40 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); 41 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); 42 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); 43 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne); 44 return success(); 45 } 46 }; 47 48 /// Expands CeilDivSIOp (n, m) into 49 /// 1) x = (m > 0) ? -1 : 1 50 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 51 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 52 using OpRewritePattern::OpRewritePattern; 53 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 54 PatternRewriter &rewriter) const final { 55 Location loc = op.getLoc(); 56 Type type = op.getType(); 57 Value a = op.getLhs(); 58 Value b = op.getRhs(); 59 Value plusOne = createConst(loc, type, 1, rewriter); 60 Value zero = createConst(loc, type, 0, rewriter); 61 Value minusOne = createConst(loc, type, -1, rewriter); 62 // Compute x = (b>0) ? -1 : 1. 63 Value compare = 64 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 65 Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne); 66 // Compute positive res: 1 + ((x+a)/b). 67 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 68 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 69 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 70 // Compute negative res: - ((-a)/b). 71 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 72 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 73 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 74 // Result is (a*b>0) ? pos result : neg result. 75 // Note, we want to avoid using a*b because of possible overflow. 76 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 77 // not particuliarly care if a*b<0 is true or false when b is zero 78 // as this will result in an illegal divide. So `a*b<0` can be reformulated 79 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 80 // We pick the first expression here. 81 Value aNeg = 82 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 83 Value aPos = 84 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 85 Value bNeg = 86 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 87 Value bPos = 88 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 89 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 90 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 91 Value compareRes = 92 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 93 // Perform substitution and return success. 94 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes, 95 negRes); 96 return success(); 97 } 98 }; 99 100 /// Expands FloorDivSIOp (n, m) into 101 /// 1) x = (m<0) ? 1 : -1 102 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 103 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 104 using OpRewritePattern::OpRewritePattern; 105 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 106 PatternRewriter &rewriter) const final { 107 Location loc = op.getLoc(); 108 Type type = op.getType(); 109 Value a = op.getLhs(); 110 Value b = op.getRhs(); 111 Value plusOne = createConst(loc, type, 1, rewriter); 112 Value zero = createConst(loc, type, 0, rewriter); 113 Value minusOne = createConst(loc, type, -1, rewriter); 114 // Compute x = (b<0) ? 1 : -1. 115 Value compare = 116 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 117 Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne); 118 // Compute negative res: -1 - ((x-a)/b). 119 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 120 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 121 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 122 // Compute positive res: a/b. 123 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 124 // Result is (a*b<0) ? negative result : positive result. 125 // Note, we want to avoid using a*b because of possible overflow. 126 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 127 // not particuliarly care if a*b<0 is true or false when b is zero 128 // as this will result in an illegal divide. So `a*b<0` can be reformulated 129 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 130 // We pick the first expression here. 131 Value aNeg = 132 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 133 Value aPos = 134 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 135 Value bNeg = 136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 137 Value bPos = 138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 139 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 140 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 141 Value compareRes = 142 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 143 // Perform substitution and return success. 144 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes, 145 posRes); 146 return success(); 147 } 148 }; 149 150 template <typename OpTy, arith::CmpFPredicate pred> 151 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> { 152 public: 153 using OpRewritePattern<OpTy>::OpRewritePattern; 154 155 LogicalResult matchAndRewrite(OpTy op, 156 PatternRewriter &rewriter) const final { 157 Value lhs = op.getLhs(); 158 Value rhs = op.getRhs(); 159 160 Location loc = op.getLoc(); 161 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). 162 static_assert(pred == arith::CmpFPredicate::UGT || 163 pred == arith::CmpFPredicate::ULT, 164 "pred must be either UGT or ULT"); 165 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); 166 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); 167 168 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. 169 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, 170 rhs, rhs); 171 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); 172 return success(); 173 } 174 }; 175 176 template <typename OpTy, arith::CmpIPredicate pred> 177 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { 178 public: 179 using OpRewritePattern<OpTy>::OpRewritePattern; 180 LogicalResult matchAndRewrite(OpTy op, 181 PatternRewriter &rewriter) const final { 182 Value lhs = op.getLhs(); 183 Value rhs = op.getRhs(); 184 185 Location loc = op.getLoc(); 186 Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs); 187 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs); 188 return success(); 189 } 190 }; 191 192 struct ArithmeticExpandOpsPass 193 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 194 void runOnOperation() override { 195 RewritePatternSet patterns(&getContext()); 196 ConversionTarget target(getContext()); 197 198 arith::populateArithmeticExpandOpsPatterns(patterns); 199 200 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>(); 201 // clang-format off 202 target.addIllegalOp< 203 arith::CeilDivSIOp, 204 arith::CeilDivUIOp, 205 arith::FloorDivSIOp, 206 arith::MaxFOp, 207 arith::MaxSIOp, 208 arith::MaxUIOp, 209 arith::MinFOp, 210 arith::MinSIOp, 211 arith::MinUIOp 212 >(); 213 // clang-format on 214 if (failed(applyPartialConversion(getOperation(), target, 215 std::move(patterns)))) 216 signalPassFailure(); 217 } 218 }; 219 220 } // namespace 221 222 void mlir::arith::populateArithmeticExpandOpsPatterns( 223 RewritePatternSet &patterns) { 224 // clang-format off 225 patterns.add< 226 CeilDivSIOpConverter, 227 CeilDivUIOpConverter, 228 FloorDivSIOpConverter, 229 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>, 230 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>, 231 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, 232 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, 233 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, 234 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult> 235 >(patterns.getContext()); 236 // clang-format on 237 } 238 239 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 240 return std::make_unique<ArithmeticExpandOpsPass>(); 241 } 242