1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 11 12 using namespace mlir; 13 14 namespace { 15 16 /// Expands CeilDivUIOp (n, m) into 17 /// n == 0 ? 0 : ((n-1) / m) + 1 18 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { 19 using OpRewritePattern::OpRewritePattern; 20 LogicalResult matchAndRewrite(arith::CeilDivUIOp op, 21 PatternRewriter &rewriter) const final { 22 Location loc = op.getLoc(); 23 Value a = op.lhs(); 24 Value b = op.rhs(); 25 Value zero = rewriter.create<arith::ConstantOp>( 26 loc, rewriter.getIntegerAttr(a.getType(), 0)); 27 Value compare = 28 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); 29 Value one = rewriter.create<arith::ConstantOp>( 30 loc, rewriter.getIntegerAttr(a.getType(), 1)); 31 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); 32 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); 33 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); 34 Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne); 35 rewriter.replaceOp(op, {res}); 36 return success(); 37 } 38 }; 39 40 /// Expands CeilDivSIOp (n, m) into 41 /// 1) x = (m > 0) ? -1 : 1 42 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 43 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 44 using OpRewritePattern::OpRewritePattern; 45 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 46 PatternRewriter &rewriter) const final { 47 Location loc = op.getLoc(); 48 auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op); 49 Type type = signedCeilDivIOp.getType(); 50 Value a = signedCeilDivIOp.getLhs(); 51 Value b = signedCeilDivIOp.getRhs(); 52 Value plusOne = rewriter.create<arith::ConstantOp>( 53 loc, rewriter.getIntegerAttr(type, 1)); 54 Value zero = rewriter.create<arith::ConstantOp>( 55 loc, rewriter.getIntegerAttr(type, 0)); 56 Value minusOne = rewriter.create<arith::ConstantOp>( 57 loc, rewriter.getIntegerAttr(type, -1)); 58 // Compute x = (b>0) ? -1 : 1. 59 Value compare = 60 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 61 Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne); 62 // Compute positive res: 1 + ((x+a)/b). 63 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 64 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 65 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 66 // Compute negative res: - ((-a)/b). 67 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 68 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 69 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 70 // Result is (a*b>0) ? pos result : neg result. 71 // Note, we want to avoid using a*b because of possible overflow. 72 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 73 // not particuliarly care if a*b<0 is true or false when b is zero 74 // as this will result in an illegal divide. So `a*b<0` can be reformulated 75 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 76 // We pick the first expression here. 77 Value aNeg = 78 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 79 Value aPos = 80 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 81 Value bNeg = 82 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 83 Value bPos = 84 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 85 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 86 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 87 Value compareRes = 88 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 89 Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes); 90 // Perform substitution and return success. 91 rewriter.replaceOp(op, {res}); 92 return success(); 93 } 94 }; 95 96 /// Expands FloorDivSIOp (n, m) into 97 /// 1) x = (m<0) ? 1 : -1 98 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 99 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 100 using OpRewritePattern::OpRewritePattern; 101 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 102 PatternRewriter &rewriter) const final { 103 Location loc = op.getLoc(); 104 arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op); 105 Type type = signedFloorDivIOp.getType(); 106 Value a = signedFloorDivIOp.getLhs(); 107 Value b = signedFloorDivIOp.getRhs(); 108 Value plusOne = rewriter.create<arith::ConstantOp>( 109 loc, rewriter.getIntegerAttr(type, 1)); 110 Value zero = rewriter.create<arith::ConstantOp>( 111 loc, rewriter.getIntegerAttr(type, 0)); 112 Value minusOne = rewriter.create<arith::ConstantOp>( 113 loc, rewriter.getIntegerAttr(type, -1)); 114 // Compute x = (b<0) ? 1 : -1. 115 Value compare = 116 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 117 Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne); 118 // Compute negative res: -1 - ((x-a)/b). 119 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 120 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 121 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 122 // Compute positive res: a/b. 123 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 124 // Result is (a*b<0) ? negative result : positive result. 125 // Note, we want to avoid using a*b because of possible overflow. 126 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 127 // not particuliarly care if a*b<0 is true or false when b is zero 128 // as this will result in an illegal divide. So `a*b<0` can be reformulated 129 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 130 // We pick the first expression here. 131 Value aNeg = 132 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 133 Value aPos = 134 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 135 Value bNeg = 136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 137 Value bPos = 138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 139 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 140 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 141 Value compareRes = 142 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 143 Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes); 144 // Perform substitution and return success. 145 rewriter.replaceOp(op, {res}); 146 return success(); 147 } 148 }; 149 150 struct ArithmeticExpandOpsPass 151 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 152 void runOnFunction() override { 153 RewritePatternSet patterns(&getContext()); 154 ConversionTarget target(getContext()); 155 156 arith::populateArithmeticExpandOpsPatterns(patterns); 157 158 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>(); 159 target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp, 160 arith::FloorDivSIOp>(); 161 162 if (failed( 163 applyPartialConversion(getFunction(), target, std::move(patterns)))) 164 signalPassFailure(); 165 } 166 }; 167 168 } // end anonymous namespace 169 170 void mlir::arith::populateArithmeticExpandOpsPatterns( 171 RewritePatternSet &patterns) { 172 patterns 173 .add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>( 174 patterns.getContext()); 175 } 176 177 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 178 return std::make_unique<ArithmeticExpandOpsPass>(); 179 } 180