1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 11 #include "mlir/IR/TypeUtilities.h" 12 13 using namespace mlir; 14 15 namespace { 16 17 /// Expands CeilDivUIOp (n, m) into 18 /// n == 0 ? 0 : ((n-1) / m) + 1 19 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { 20 using OpRewritePattern::OpRewritePattern; 21 LogicalResult matchAndRewrite(arith::CeilDivUIOp op, 22 PatternRewriter &rewriter) const final { 23 Location loc = op.getLoc(); 24 Value a = op.lhs(); 25 Value b = op.rhs(); 26 Value zero = rewriter.create<arith::ConstantOp>( 27 loc, rewriter.getIntegerAttr(a.getType(), 0)); 28 Value compare = 29 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); 30 Value one = rewriter.create<arith::ConstantOp>( 31 loc, rewriter.getIntegerAttr(a.getType(), 1)); 32 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); 33 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); 34 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); 35 Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne); 36 rewriter.replaceOp(op, {res}); 37 return success(); 38 } 39 }; 40 41 /// Expands CeilDivSIOp (n, m) into 42 /// 1) x = (m > 0) ? -1 : 1 43 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 44 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 45 using OpRewritePattern::OpRewritePattern; 46 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 47 PatternRewriter &rewriter) const final { 48 Location loc = op.getLoc(); 49 auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op); 50 Type type = signedCeilDivIOp.getType(); 51 Value a = signedCeilDivIOp.getLhs(); 52 Value b = signedCeilDivIOp.getRhs(); 53 Value plusOne = rewriter.create<arith::ConstantOp>( 54 loc, rewriter.getIntegerAttr(type, 1)); 55 Value zero = rewriter.create<arith::ConstantOp>( 56 loc, rewriter.getIntegerAttr(type, 0)); 57 Value minusOne = rewriter.create<arith::ConstantOp>( 58 loc, rewriter.getIntegerAttr(type, -1)); 59 // Compute x = (b>0) ? -1 : 1. 60 Value compare = 61 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 62 Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne); 63 // Compute positive res: 1 + ((x+a)/b). 64 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 65 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 66 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 67 // Compute negative res: - ((-a)/b). 68 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 69 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 70 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 71 // Result is (a*b>0) ? pos result : neg result. 72 // Note, we want to avoid using a*b because of possible overflow. 73 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 74 // not particuliarly care if a*b<0 is true or false when b is zero 75 // as this will result in an illegal divide. So `a*b<0` can be reformulated 76 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 77 // We pick the first expression here. 78 Value aNeg = 79 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 80 Value aPos = 81 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 82 Value bNeg = 83 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 84 Value bPos = 85 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 86 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 87 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 88 Value compareRes = 89 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 90 Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes); 91 // Perform substitution and return success. 92 rewriter.replaceOp(op, {res}); 93 return success(); 94 } 95 }; 96 97 /// Expands FloorDivSIOp (n, m) into 98 /// 1) x = (m<0) ? 1 : -1 99 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 100 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 101 using OpRewritePattern::OpRewritePattern; 102 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 103 PatternRewriter &rewriter) const final { 104 Location loc = op.getLoc(); 105 arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op); 106 Type type = signedFloorDivIOp.getType(); 107 Value a = signedFloorDivIOp.getLhs(); 108 Value b = signedFloorDivIOp.getRhs(); 109 Value plusOne = rewriter.create<arith::ConstantOp>( 110 loc, rewriter.getIntegerAttr(type, 1)); 111 Value zero = rewriter.create<arith::ConstantOp>( 112 loc, rewriter.getIntegerAttr(type, 0)); 113 Value minusOne = rewriter.create<arith::ConstantOp>( 114 loc, rewriter.getIntegerAttr(type, -1)); 115 // Compute x = (b<0) ? 1 : -1. 116 Value compare = 117 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 118 Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne); 119 // Compute negative res: -1 - ((x-a)/b). 120 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 121 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 122 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 123 // Compute positive res: a/b. 124 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 125 // Result is (a*b<0) ? negative result : positive result. 126 // Note, we want to avoid using a*b because of possible overflow. 127 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 128 // not particuliarly care if a*b<0 is true or false when b is zero 129 // as this will result in an illegal divide. So `a*b<0` can be reformulated 130 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 131 // We pick the first expression here. 132 Value aNeg = 133 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 134 Value aPos = 135 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 136 Value bNeg = 137 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 138 Value bPos = 139 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 140 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 141 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 142 Value compareRes = 143 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 144 Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes); 145 // Perform substitution and return success. 146 rewriter.replaceOp(op, {res}); 147 return success(); 148 } 149 }; 150 151 template <typename OpTy, arith::CmpFPredicate pred> 152 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> { 153 public: 154 using OpRewritePattern<OpTy>::OpRewritePattern; 155 156 LogicalResult matchAndRewrite(OpTy op, 157 PatternRewriter &rewriter) const final { 158 Value lhs = op.getLhs(); 159 Value rhs = op.getRhs(); 160 161 Location loc = op.getLoc(); 162 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); 163 Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs); 164 165 auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>(); 166 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, 167 lhs, rhs); 168 169 Value nan = rewriter.create<arith::ConstantFloatOp>( 170 loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); 171 if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>()) 172 nan = rewriter.create<SplatOp>(loc, vectorType, nan); 173 174 rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select); 175 return success(); 176 } 177 }; 178 179 template <typename OpTy, arith::CmpIPredicate pred> 180 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { 181 public: 182 using OpRewritePattern<OpTy>::OpRewritePattern; 183 LogicalResult matchAndRewrite(OpTy op, 184 PatternRewriter &rewriter) const final { 185 Value lhs = op.getLhs(); 186 Value rhs = op.getRhs(); 187 188 Location loc = op.getLoc(); 189 Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs); 190 rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs); 191 return success(); 192 } 193 }; 194 195 struct ArithmeticExpandOpsPass 196 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 197 void runOnFunction() override { 198 RewritePatternSet patterns(&getContext()); 199 ConversionTarget target(getContext()); 200 201 arith::populateArithmeticExpandOpsPatterns(patterns); 202 203 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>(); 204 // clang-format off 205 target.addIllegalOp< 206 arith::CeilDivSIOp, 207 arith::CeilDivUIOp, 208 arith::FloorDivSIOp, 209 arith::MaxFOp, 210 arith::MaxSIOp, 211 arith::MaxUIOp, 212 arith::MinFOp, 213 arith::MinSIOp, 214 arith::MinUIOp 215 >(); 216 // clang-format on 217 if (failed( 218 applyPartialConversion(getFunction(), target, std::move(patterns)))) 219 signalPassFailure(); 220 } 221 }; 222 223 } // end anonymous namespace 224 225 void mlir::arith::populateArithmeticExpandOpsPatterns( 226 RewritePatternSet &patterns) { 227 // clang-format off 228 patterns.add< 229 CeilDivSIOpConverter, 230 CeilDivUIOpConverter, 231 FloorDivSIOpConverter, 232 MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>, 233 MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>, 234 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, 235 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, 236 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, 237 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult> 238 >(patterns.getContext()); 239 // clang-format on 240 } 241 242 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 243 return std::make_unique<ArithmeticExpandOpsPass>(); 244 } 245