1 //===- ExpandOps.cpp - Pass to legalize Arithmetic ops for LLVM lowering --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" 11 12 using namespace mlir; 13 14 namespace { 15 16 /// Expands CeilDivSIOp (n, m) into 17 /// 1) x = (m > 0) ? -1 : 1 18 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) 19 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { 20 using OpRewritePattern::OpRewritePattern; 21 LogicalResult matchAndRewrite(arith::CeilDivSIOp op, 22 PatternRewriter &rewriter) const final { 23 Location loc = op.getLoc(); 24 auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op); 25 Type type = signedCeilDivIOp.getType(); 26 Value a = signedCeilDivIOp.getLhs(); 27 Value b = signedCeilDivIOp.getRhs(); 28 Value plusOne = rewriter.create<arith::ConstantOp>( 29 loc, rewriter.getIntegerAttr(type, 1)); 30 Value zero = rewriter.create<arith::ConstantOp>( 31 loc, rewriter.getIntegerAttr(type, 0)); 32 Value minusOne = rewriter.create<arith::ConstantOp>( 33 loc, rewriter.getIntegerAttr(type, -1)); 34 // Compute x = (b>0) ? -1 : 1. 35 Value compare = 36 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 37 Value x = rewriter.create<SelectOp>(loc, compare, minusOne, plusOne); 38 // Compute positive res: 1 + ((x+a)/b). 39 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); 40 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); 41 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); 42 // Compute negative res: - ((-a)/b). 43 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); 44 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); 45 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); 46 // Result is (a*b>0) ? pos result : neg result. 47 // Note, we want to avoid using a*b because of possible overflow. 48 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 49 // not particuliarly care if a*b<0 is true or false when b is zero 50 // as this will result in an illegal divide. So `a*b<0` can be reformulated 51 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. 52 // We pick the first expression here. 53 Value aNeg = 54 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 55 Value aPos = 56 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 57 Value bNeg = 58 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 59 Value bPos = 60 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 61 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); 62 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); 63 Value compareRes = 64 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 65 Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes); 66 // Perform substitution and return success. 67 rewriter.replaceOp(op, {res}); 68 return success(); 69 } 70 }; 71 72 /// Expands FloorDivSIOp (n, m) into 73 /// 1) x = (m<0) ? 1 : -1 74 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m 75 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { 76 using OpRewritePattern::OpRewritePattern; 77 LogicalResult matchAndRewrite(arith::FloorDivSIOp op, 78 PatternRewriter &rewriter) const final { 79 Location loc = op.getLoc(); 80 arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op); 81 Type type = signedFloorDivIOp.getType(); 82 Value a = signedFloorDivIOp.getLhs(); 83 Value b = signedFloorDivIOp.getRhs(); 84 Value plusOne = rewriter.create<arith::ConstantOp>( 85 loc, rewriter.getIntegerAttr(type, 1)); 86 Value zero = rewriter.create<arith::ConstantOp>( 87 loc, rewriter.getIntegerAttr(type, 0)); 88 Value minusOne = rewriter.create<arith::ConstantOp>( 89 loc, rewriter.getIntegerAttr(type, -1)); 90 // Compute x = (b<0) ? 1 : -1. 91 Value compare = 92 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 93 Value x = rewriter.create<SelectOp>(loc, compare, plusOne, minusOne); 94 // Compute negative res: -1 - ((x-a)/b). 95 Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a); 96 Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b); 97 Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB); 98 // Compute positive res: a/b. 99 Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b); 100 // Result is (a*b<0) ? negative result : positive result. 101 // Note, we want to avoid using a*b because of possible overflow. 102 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do 103 // not particuliarly care if a*b<0 is true or false when b is zero 104 // as this will result in an illegal divide. So `a*b<0` can be reformulated 105 // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'. 106 // We pick the first expression here. 107 Value aNeg = 108 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); 109 Value aPos = 110 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); 111 Value bNeg = 112 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); 113 Value bPos = 114 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); 115 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos); 116 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg); 117 Value compareRes = 118 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); 119 Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes); 120 // Perform substitution and return success. 121 rewriter.replaceOp(op, {res}); 122 return success(); 123 } 124 }; 125 126 struct ArithmeticExpandOpsPass 127 : public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> { 128 void runOnFunction() override { 129 RewritePatternSet patterns(&getContext()); 130 ConversionTarget target(getContext()); 131 132 arith::populateArithmeticExpandOpsPatterns(patterns); 133 134 target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>(); 135 target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>(); 136 137 if (failed( 138 applyPartialConversion(getFunction(), target, std::move(patterns)))) 139 signalPassFailure(); 140 } 141 }; 142 143 } // end anonymous namespace 144 145 void mlir::arith::populateArithmeticExpandOpsPatterns( 146 RewritePatternSet &patterns) { 147 patterns.add<CeilDivSIOpConverter, FloorDivSIOpConverter>( 148 patterns.getContext()); 149 } 150 151 std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() { 152 return std::make_unique<ArithmeticExpandOpsPass>(); 153 } 154