1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/IR/TypeUtilities.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 16 using namespace mlir; 17 18 namespace { 19 20 /// Expands CeilDivUIOp (n, m) into 21 /// n == 0 ? 0 : ((n-1) / m) + 1 22 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { 23 using OpRewritePattern::OpRewritePattern; 24 LogicalResult matchAndRewrite(arith::CeilDivUIOp op, 25 PatternRewriter &rewriter) const final { 26 Location loc = op.getLoc(); 27 Value a = op.getLhs(); 28 Value b = op.getRhs(); 29 Value zero = rewriter.create<arith::ConstantOp>( 30 loc, rewriter.getIntegerAttr(a.getType(), 0)); 31 Value compare = 32 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); 33 Value one = rewriter.create<arith::ConstantOp>( 34 loc, rewriter.getIntegerAttr(a.getType(), 1)); 35 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); 36 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); 37 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); 38 Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne); 39 rewriter.replaceOp(op, {res}); 40 return success(); 41 } 42 }; 43 44 /// Expands CeilDivSIOp (n, m) into 45 /// 1) x = (m > 0) ? -1 : 1 46 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 47 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 48 using OpRewritePattern::OpRewritePattern; 49 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 50 PatternRewriter &rewriter) const final { 51 Location loc = op.getLoc(); 52 auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op); 53 Type type = signedCeilDivIOp.getType(); 54 Value a = signedCeilDivIOp.getLhs(); 55 Value b = signedCeilDivIOp.getRhs(); 56 Value plusOne = rewriter.create<arith::ConstantOp>( 57 loc, rewriter.getIntegerAttr(type, 1)); 58 Value zero = rewriter.create<arith::ConstantOp>( 59 loc, rewriter.getIntegerAttr(type, 0)); 60 Value minusOne = rewriter.create<arith::ConstantOp>( 61 loc, rewriter.getIntegerAttr(type, -1)); 62 // Compute x = (b>0) ? -1 : 1. 63 Value compare = 64 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 65 Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne); 66 // Compute positive res: 1 + ((x+a)/b). 67 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 68 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 69 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 70 // Compute negative res: - ((-a)/b). 71 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 72 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 73 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 74 // Result is (a*b>0) ? pos result : neg result. 75 // Note, we want to avoid using a*b because of possible overflow. 76 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 77 // not particuliarly care if a*b<0 is true or false when b is zero 78 // as this will result in an illegal divide. So `a*b<0` can be reformulated 79 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 80 // We pick the first expression here. 81 Value aNeg = 82 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 83 Value aPos = 84 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 85 Value bNeg = 86 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 87 Value bPos = 88 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 89 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 90 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 91 Value compareRes = 92 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 93 Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes); 94 // Perform substitution and return success. 95 rewriter.replaceOp(op, {res}); 96 return success(); 97 } 98 }; 99 100 /// Expands FloorDivSIOp (n, m) into 101 /// 1) x = (m<0) ? 1 : -1 102 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 103 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 104 using OpRewritePattern::OpRewritePattern; 105 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 106 PatternRewriter &rewriter) const final { 107 Location loc = op.getLoc(); 108 arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op); 109 Type type = signedFloorDivIOp.getType(); 110 Value a = signedFloorDivIOp.getLhs(); 111 Value b = signedFloorDivIOp.getRhs(); 112 Value plusOne = rewriter.create<arith::ConstantOp>( 113 loc, rewriter.getIntegerAttr(type, 1)); 114 Value zero = rewriter.create<arith::ConstantOp>( 115 loc, rewriter.getIntegerAttr(type, 0)); 116 Value minusOne = rewriter.create<arith::ConstantOp>( 117 loc, rewriter.getIntegerAttr(type, -1)); 118 // Compute x = (b<0) ? 1 : -1. 119 Value compare = 120 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 121 Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne); 122 // Compute negative res: -1 - ((x-a)/b). 123 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 124 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 125 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 126 // Compute positive res: a/b. 127 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 128 // Result is (a*b<0) ? negative result : positive result. 129 // Note, we want to avoid using a*b because of possible overflow. 130 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 131 // not particuliarly care if a*b<0 is true or false when b is zero 132 // as this will result in an illegal divide. So `a*b<0` can be reformulated 133 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 134 // We pick the first expression here. 135 Value aNeg = 136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 137 Value aPos = 138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 139 Value bNeg = 140 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 141 Value bPos = 142 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 143 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 144 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 145 Value compareRes = 146 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 147 Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes); 148 // Perform substitution and return success. 149 rewriter.replaceOp(op, {res}); 150 return success(); 151 } 152 }; 153 154 template <typename OpTy, arith::CmpFPredicate pred> 155 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> { 156 public: 157 using OpRewritePattern<OpTy>::OpRewritePattern; 158 159 LogicalResult matchAndRewrite(OpTy op, 160 PatternRewriter &rewriter) const final { 161 Value lhs = op.getLhs(); 162 Value rhs = op.getRhs(); 163 164 Location loc = op.getLoc(); 165 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); 166 Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs); 167 168 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 169 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, 170 lhs, rhs); 171 172 Value nan = rewriter.create<arith::ConstantFloatOp>( 173 loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); 174 if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>()) 175 nan = rewriter.create<SplatOp>(loc, vectorType, nan); 176 177 rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select); 178 return success(); 179 } 180 }; 181 182 template <typename OpTy, arith::CmpIPredicate pred> 183 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { 184 public: 185 using OpRewritePattern<OpTy>::OpRewritePattern; 186 LogicalResult matchAndRewrite(OpTy op, 187 PatternRewriter &rewriter) const final { 188 Value lhs = op.getLhs(); 189 Value rhs = op.getRhs(); 190 191 Location loc = op.getLoc(); 192 Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs); 193 rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs); 194 return success(); 195 } 196 }; 197 198 struct ArithmeticExpandOpsPass 199 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 200 void runOnFunction() override { 201 RewritePatternSet patterns(&getContext()); 202 ConversionTarget target(getContext()); 203 204 arith::populateArithmeticExpandOpsPatterns(patterns); 205 206 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>(); 207 // clang-format off 208 target.addIllegalOp< 209 arith::CeilDivSIOp, 210 arith::CeilDivUIOp, 211 arith::FloorDivSIOp, 212 arith::MaxFOp, 213 arith::MaxSIOp, 214 arith::MaxUIOp, 215 arith::MinFOp, 216 arith::MinSIOp, 217 arith::MinUIOp 218 >(); 219 // clang-format on 220 if (failed( 221 applyPartialConversion(getFunction(), target, std::move(patterns)))) 222 signalPassFailure(); 223 } 224 }; 225 226 } // namespace 227 228 void mlir::arith::populateArithmeticExpandOpsPatterns( 229 RewritePatternSet &patterns) { 230 // clang-format off 231 patterns.add< 232 CeilDivSIOpConverter, 233 CeilDivUIOpConverter, 234 FloorDivSIOpConverter, 235 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>, 236 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>, 237 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, 238 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, 239 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, 240 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult> 241 >(patterns.getContext()); 242 // clang-format on 243 } 244 245 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 246 return std::make_unique<ArithmeticExpandOpsPass>(); 247 } 248