1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/IR/TypeUtilities.h" 13 #include "mlir/Transforms/DialectConversion.h" 14 15 using namespace mlir; 16 17 /// Create an integer or index constant. 18 static Value createConst(Location loc, Type type, int value, 19 PatternRewriter &rewriter) { 20 return rewriter.create<arith::ConstantOp>( 21 loc, rewriter.getIntegerAttr(type, value)); 22 } 23 24 namespace { 25 26 /// Expands CeilDivUIOp (n, m) into 27 /// n == 0 ? 0 : ((n-1) / m) + 1 28 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { 29 using OpRewritePattern::OpRewritePattern; 30 LogicalResult matchAndRewrite(arith::CeilDivUIOp op, 31 PatternRewriter &rewriter) const final { 32 Location loc = op.getLoc(); 33 Value a = op.getLhs(); 34 Value b = op.getRhs(); 35 Value zero = createConst(loc, a.getType(), 0, rewriter); 36 Value compare = 37 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); 38 Value one = createConst(loc, a.getType(), 1, rewriter); 39 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); 40 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); 41 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); 42 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne); 43 return success(); 44 } 45 }; 46 47 /// Expands CeilDivSIOp (n, m) into 48 /// 1) x = (m > 0) ? -1 : 1 49 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 50 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 51 using OpRewritePattern::OpRewritePattern; 52 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 53 PatternRewriter &rewriter) const final { 54 Location loc = op.getLoc(); 55 Type type = op.getType(); 56 Value a = op.getLhs(); 57 Value b = op.getRhs(); 58 Value plusOne = createConst(loc, type, 1, rewriter); 59 Value zero = createConst(loc, type, 0, rewriter); 60 Value minusOne = createConst(loc, type, -1, rewriter); 61 // Compute x = (b>0) ? -1 : 1. 62 Value compare = 63 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 64 Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne); 65 // Compute positive res: 1 + ((x+a)/b). 66 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 67 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 68 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 69 // Compute negative res: - ((-a)/b). 70 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 71 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 72 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 73 // Result is (a*b>0) ? pos result : neg result. 74 // Note, we want to avoid using a*b because of possible overflow. 75 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 76 // not particuliarly care if a*b<0 is true or false when b is zero 77 // as this will result in an illegal divide. So `a*b<0` can be reformulated 78 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 79 // We pick the first expression here. 80 Value aNeg = 81 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 82 Value aPos = 83 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 84 Value bNeg = 85 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 86 Value bPos = 87 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 88 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 89 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 90 Value compareRes = 91 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 92 // Perform substitution and return success. 93 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes, 94 negRes); 95 return success(); 96 } 97 }; 98 99 /// Expands FloorDivSIOp (n, m) into 100 /// 1) x = (m<0) ? 1 : -1 101 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 102 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 103 using OpRewritePattern::OpRewritePattern; 104 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 105 PatternRewriter &rewriter) const final { 106 Location loc = op.getLoc(); 107 Type type = op.getType(); 108 Value a = op.getLhs(); 109 Value b = op.getRhs(); 110 Value plusOne = createConst(loc, type, 1, rewriter); 111 Value zero = createConst(loc, type, 0, rewriter); 112 Value minusOne = createConst(loc, type, -1, rewriter); 113 // Compute x = (b<0) ? 1 : -1. 114 Value compare = 115 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 116 Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne); 117 // Compute negative res: -1 - ((x-a)/b). 118 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 119 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 120 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 121 // Compute positive res: a/b. 122 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 123 // Result is (a*b<0) ? negative result : positive result. 124 // Note, we want to avoid using a*b because of possible overflow. 125 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 126 // not particuliarly care if a*b<0 is true or false when b is zero 127 // as this will result in an illegal divide. So `a*b<0` can be reformulated 128 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 129 // We pick the first expression here. 130 Value aNeg = 131 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 132 Value aPos = 133 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 134 Value bNeg = 135 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 136 Value bPos = 137 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 138 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 139 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 140 Value compareRes = 141 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 142 // Perform substitution and return success. 143 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes, 144 posRes); 145 return success(); 146 } 147 }; 148 149 template <typename OpTy, arith::CmpFPredicate pred> 150 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> { 151 public: 152 using OpRewritePattern<OpTy>::OpRewritePattern; 153 154 LogicalResult matchAndRewrite(OpTy op, 155 PatternRewriter &rewriter) const final { 156 Value lhs = op.getLhs(); 157 Value rhs = op.getRhs(); 158 159 Location loc = op.getLoc(); 160 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). 161 static_assert(pred == arith::CmpFPredicate::UGT || 162 pred == arith::CmpFPredicate::ULT, 163 "pred must be either UGT or ULT"); 164 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); 165 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); 166 167 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. 168 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, 169 rhs, rhs); 170 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); 171 return success(); 172 } 173 }; 174 175 template <typename OpTy, arith::CmpIPredicate pred> 176 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { 177 public: 178 using OpRewritePattern<OpTy>::OpRewritePattern; 179 LogicalResult matchAndRewrite(OpTy op, 180 PatternRewriter &rewriter) const final { 181 Value lhs = op.getLhs(); 182 Value rhs = op.getRhs(); 183 184 Location loc = op.getLoc(); 185 Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs); 186 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs); 187 return success(); 188 } 189 }; 190 191 struct ArithmeticExpandOpsPass 192 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 193 void runOnOperation() override { 194 RewritePatternSet patterns(&getContext()); 195 ConversionTarget target(getContext()); 196 197 arith::populateArithmeticExpandOpsPatterns(patterns); 198 199 target.addLegalDialect<arith::ArithmeticDialect>(); 200 // clang-format off 201 target.addIllegalOp< 202 arith::CeilDivSIOp, 203 arith::CeilDivUIOp, 204 arith::FloorDivSIOp, 205 arith::MaxFOp, 206 arith::MaxSIOp, 207 arith::MaxUIOp, 208 arith::MinFOp, 209 arith::MinSIOp, 210 arith::MinUIOp 211 >(); 212 // clang-format on 213 if (failed(applyPartialConversion(getOperation(), target, 214 std::move(patterns)))) 215 signalPassFailure(); 216 } 217 }; 218 219 } // namespace 220 221 void mlir::arith::populateArithmeticExpandOpsPatterns( 222 RewritePatternSet &patterns) { 223 // clang-format off 224 patterns.add< 225 CeilDivSIOpConverter, 226 CeilDivUIOpConverter, 227 FloorDivSIOpConverter, 228 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>, 229 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>, 230 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, 231 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, 232 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, 233 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult> 234 >(patterns.getContext()); 235 // clang-format on 236 } 237 238 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 239 return std::make_unique<ArithmeticExpandOpsPass>(); 240 } 241