11e2691feSRoman Lebedev //===- ScalarEvolutionDivision.h - See below --------------------*- C++ -*-===//
21e2691feSRoman Lebedev //
31e2691feSRoman Lebedev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
41e2691feSRoman Lebedev // See https://llvm.org/LICENSE.txt for license information.
51e2691feSRoman Lebedev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
61e2691feSRoman Lebedev //
71e2691feSRoman Lebedev //===----------------------------------------------------------------------===//
81e2691feSRoman Lebedev //
91e2691feSRoman Lebedev // This file defines the class that knows how to divide SCEV's.
101e2691feSRoman Lebedev //
111e2691feSRoman Lebedev //===----------------------------------------------------------------------===//
121e2691feSRoman Lebedev
131e2691feSRoman Lebedev #include "llvm/Analysis/ScalarEvolutionDivision.h"
141e2691feSRoman Lebedev #include "llvm/ADT/APInt.h"
151e2691feSRoman Lebedev #include "llvm/ADT/DenseMap.h"
161e2691feSRoman Lebedev #include "llvm/ADT/SmallVector.h"
171e2691feSRoman Lebedev #include "llvm/Analysis/ScalarEvolution.h"
181e2691feSRoman Lebedev #include "llvm/Support/Casting.h"
191e2691feSRoman Lebedev #include <cassert>
201e2691feSRoman Lebedev #include <cstdint>
211e2691feSRoman Lebedev
221e2691feSRoman Lebedev namespace llvm {
231e2691feSRoman Lebedev class Type;
241e2691feSRoman Lebedev }
251e2691feSRoman Lebedev
261e2691feSRoman Lebedev using namespace llvm;
271e2691feSRoman Lebedev
281e2691feSRoman Lebedev namespace {
291e2691feSRoman Lebedev
sizeOfSCEV(const SCEV * S)301e2691feSRoman Lebedev static inline int sizeOfSCEV(const SCEV *S) {
311e2691feSRoman Lebedev struct FindSCEVSize {
321e2691feSRoman Lebedev int Size = 0;
331e2691feSRoman Lebedev
341e2691feSRoman Lebedev FindSCEVSize() = default;
351e2691feSRoman Lebedev
361e2691feSRoman Lebedev bool follow(const SCEV *S) {
371e2691feSRoman Lebedev ++Size;
381e2691feSRoman Lebedev // Keep looking at all operands of S.
391e2691feSRoman Lebedev return true;
401e2691feSRoman Lebedev }
411e2691feSRoman Lebedev
421e2691feSRoman Lebedev bool isDone() const { return false; }
431e2691feSRoman Lebedev };
441e2691feSRoman Lebedev
451e2691feSRoman Lebedev FindSCEVSize F;
461e2691feSRoman Lebedev SCEVTraversal<FindSCEVSize> ST(F);
471e2691feSRoman Lebedev ST.visitAll(S);
481e2691feSRoman Lebedev return F.Size;
491e2691feSRoman Lebedev }
501e2691feSRoman Lebedev
511e2691feSRoman Lebedev } // namespace
521e2691feSRoman Lebedev
531e2691feSRoman Lebedev // Computes the Quotient and Remainder of the division of Numerator by
541e2691feSRoman Lebedev // Denominator.
divide(ScalarEvolution & SE,const SCEV * Numerator,const SCEV * Denominator,const SCEV ** Quotient,const SCEV ** Remainder)551e2691feSRoman Lebedev void SCEVDivision::divide(ScalarEvolution &SE, const SCEV *Numerator,
561e2691feSRoman Lebedev const SCEV *Denominator, const SCEV **Quotient,
571e2691feSRoman Lebedev const SCEV **Remainder) {
581e2691feSRoman Lebedev assert(Numerator && Denominator && "Uninitialized SCEV");
591e2691feSRoman Lebedev
601e2691feSRoman Lebedev SCEVDivision D(SE, Numerator, Denominator);
611e2691feSRoman Lebedev
621e2691feSRoman Lebedev // Check for the trivial case here to avoid having to check for it in the
631e2691feSRoman Lebedev // rest of the code.
641e2691feSRoman Lebedev if (Numerator == Denominator) {
651e2691feSRoman Lebedev *Quotient = D.One;
661e2691feSRoman Lebedev *Remainder = D.Zero;
671e2691feSRoman Lebedev return;
681e2691feSRoman Lebedev }
691e2691feSRoman Lebedev
701e2691feSRoman Lebedev if (Numerator->isZero()) {
711e2691feSRoman Lebedev *Quotient = D.Zero;
721e2691feSRoman Lebedev *Remainder = D.Zero;
731e2691feSRoman Lebedev return;
741e2691feSRoman Lebedev }
751e2691feSRoman Lebedev
761e2691feSRoman Lebedev // A simple case when N/1. The quotient is N.
771e2691feSRoman Lebedev if (Denominator->isOne()) {
781e2691feSRoman Lebedev *Quotient = Numerator;
791e2691feSRoman Lebedev *Remainder = D.Zero;
801e2691feSRoman Lebedev return;
811e2691feSRoman Lebedev }
821e2691feSRoman Lebedev
831e2691feSRoman Lebedev // Split the Denominator when it is a product.
841e2691feSRoman Lebedev if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) {
851e2691feSRoman Lebedev const SCEV *Q, *R;
861e2691feSRoman Lebedev *Quotient = Numerator;
871e2691feSRoman Lebedev for (const SCEV *Op : T->operands()) {
881e2691feSRoman Lebedev divide(SE, *Quotient, Op, &Q, &R);
891e2691feSRoman Lebedev *Quotient = Q;
901e2691feSRoman Lebedev
911e2691feSRoman Lebedev // Bail out when the Numerator is not divisible by one of the terms of
921e2691feSRoman Lebedev // the Denominator.
931e2691feSRoman Lebedev if (!R->isZero()) {
941e2691feSRoman Lebedev *Quotient = D.Zero;
951e2691feSRoman Lebedev *Remainder = Numerator;
961e2691feSRoman Lebedev return;
971e2691feSRoman Lebedev }
981e2691feSRoman Lebedev }
991e2691feSRoman Lebedev *Remainder = D.Zero;
1001e2691feSRoman Lebedev return;
1011e2691feSRoman Lebedev }
1021e2691feSRoman Lebedev
1031e2691feSRoman Lebedev D.visit(Numerator);
1041e2691feSRoman Lebedev *Quotient = D.Quotient;
1051e2691feSRoman Lebedev *Remainder = D.Remainder;
1061e2691feSRoman Lebedev }
1071e2691feSRoman Lebedev
visitConstant(const SCEVConstant * Numerator)1081e2691feSRoman Lebedev void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
1091e2691feSRoman Lebedev if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
1101e2691feSRoman Lebedev APInt NumeratorVal = Numerator->getAPInt();
1111e2691feSRoman Lebedev APInt DenominatorVal = D->getAPInt();
1121e2691feSRoman Lebedev uint32_t NumeratorBW = NumeratorVal.getBitWidth();
1131e2691feSRoman Lebedev uint32_t DenominatorBW = DenominatorVal.getBitWidth();
1141e2691feSRoman Lebedev
1151e2691feSRoman Lebedev if (NumeratorBW > DenominatorBW)
1161e2691feSRoman Lebedev DenominatorVal = DenominatorVal.sext(NumeratorBW);
1171e2691feSRoman Lebedev else if (NumeratorBW < DenominatorBW)
1181e2691feSRoman Lebedev NumeratorVal = NumeratorVal.sext(DenominatorBW);
1191e2691feSRoman Lebedev
1201e2691feSRoman Lebedev APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
1211e2691feSRoman Lebedev APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
1221e2691feSRoman Lebedev APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
1231e2691feSRoman Lebedev Quotient = SE.getConstant(QuotientVal);
1241e2691feSRoman Lebedev Remainder = SE.getConstant(RemainderVal);
1251e2691feSRoman Lebedev return;
1261e2691feSRoman Lebedev }
1271e2691feSRoman Lebedev }
1281e2691feSRoman Lebedev
visitAddRecExpr(const SCEVAddRecExpr * Numerator)1291e2691feSRoman Lebedev void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
1301e2691feSRoman Lebedev const SCEV *StartQ, *StartR, *StepQ, *StepR;
1311e2691feSRoman Lebedev if (!Numerator->isAffine())
1321e2691feSRoman Lebedev return cannotDivide(Numerator);
1331e2691feSRoman Lebedev divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
1341e2691feSRoman Lebedev divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
1351e2691feSRoman Lebedev // Bail out if the types do not match.
1361e2691feSRoman Lebedev Type *Ty = Denominator->getType();
1371e2691feSRoman Lebedev if (Ty != StartQ->getType() || Ty != StartR->getType() ||
1381e2691feSRoman Lebedev Ty != StepQ->getType() || Ty != StepR->getType())
1391e2691feSRoman Lebedev return cannotDivide(Numerator);
1401e2691feSRoman Lebedev Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
1411e2691feSRoman Lebedev Numerator->getNoWrapFlags());
1421e2691feSRoman Lebedev Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
1431e2691feSRoman Lebedev Numerator->getNoWrapFlags());
1441e2691feSRoman Lebedev }
1451e2691feSRoman Lebedev
visitAddExpr(const SCEVAddExpr * Numerator)1461e2691feSRoman Lebedev void SCEVDivision::visitAddExpr(const SCEVAddExpr *Numerator) {
1471e2691feSRoman Lebedev SmallVector<const SCEV *, 2> Qs, Rs;
1481e2691feSRoman Lebedev Type *Ty = Denominator->getType();
1491e2691feSRoman Lebedev
1501e2691feSRoman Lebedev for (const SCEV *Op : Numerator->operands()) {
1511e2691feSRoman Lebedev const SCEV *Q, *R;
1521e2691feSRoman Lebedev divide(SE, Op, Denominator, &Q, &R);
1531e2691feSRoman Lebedev
1541e2691feSRoman Lebedev // Bail out if types do not match.
1551e2691feSRoman Lebedev if (Ty != Q->getType() || Ty != R->getType())
1561e2691feSRoman Lebedev return cannotDivide(Numerator);
1571e2691feSRoman Lebedev
1581e2691feSRoman Lebedev Qs.push_back(Q);
1591e2691feSRoman Lebedev Rs.push_back(R);
1601e2691feSRoman Lebedev }
1611e2691feSRoman Lebedev
1621e2691feSRoman Lebedev if (Qs.size() == 1) {
1631e2691feSRoman Lebedev Quotient = Qs[0];
1641e2691feSRoman Lebedev Remainder = Rs[0];
1651e2691feSRoman Lebedev return;
1661e2691feSRoman Lebedev }
1671e2691feSRoman Lebedev
1681e2691feSRoman Lebedev Quotient = SE.getAddExpr(Qs);
1691e2691feSRoman Lebedev Remainder = SE.getAddExpr(Rs);
1701e2691feSRoman Lebedev }
1711e2691feSRoman Lebedev
visitMulExpr(const SCEVMulExpr * Numerator)1721e2691feSRoman Lebedev void SCEVDivision::visitMulExpr(const SCEVMulExpr *Numerator) {
1731e2691feSRoman Lebedev SmallVector<const SCEV *, 2> Qs;
1741e2691feSRoman Lebedev Type *Ty = Denominator->getType();
1751e2691feSRoman Lebedev
1761e2691feSRoman Lebedev bool FoundDenominatorTerm = false;
1771e2691feSRoman Lebedev for (const SCEV *Op : Numerator->operands()) {
1781e2691feSRoman Lebedev // Bail out if types do not match.
1791e2691feSRoman Lebedev if (Ty != Op->getType())
1801e2691feSRoman Lebedev return cannotDivide(Numerator);
1811e2691feSRoman Lebedev
1821e2691feSRoman Lebedev if (FoundDenominatorTerm) {
1831e2691feSRoman Lebedev Qs.push_back(Op);
1841e2691feSRoman Lebedev continue;
1851e2691feSRoman Lebedev }
1861e2691feSRoman Lebedev
1871e2691feSRoman Lebedev // Check whether Denominator divides one of the product operands.
1881e2691feSRoman Lebedev const SCEV *Q, *R;
1891e2691feSRoman Lebedev divide(SE, Op, Denominator, &Q, &R);
1901e2691feSRoman Lebedev if (!R->isZero()) {
1911e2691feSRoman Lebedev Qs.push_back(Op);
1921e2691feSRoman Lebedev continue;
1931e2691feSRoman Lebedev }
1941e2691feSRoman Lebedev
1951e2691feSRoman Lebedev // Bail out if types do not match.
1961e2691feSRoman Lebedev if (Ty != Q->getType())
1971e2691feSRoman Lebedev return cannotDivide(Numerator);
1981e2691feSRoman Lebedev
1991e2691feSRoman Lebedev FoundDenominatorTerm = true;
2001e2691feSRoman Lebedev Qs.push_back(Q);
2011e2691feSRoman Lebedev }
2021e2691feSRoman Lebedev
2031e2691feSRoman Lebedev if (FoundDenominatorTerm) {
2041e2691feSRoman Lebedev Remainder = Zero;
2051e2691feSRoman Lebedev if (Qs.size() == 1)
2061e2691feSRoman Lebedev Quotient = Qs[0];
2071e2691feSRoman Lebedev else
2081e2691feSRoman Lebedev Quotient = SE.getMulExpr(Qs);
2091e2691feSRoman Lebedev return;
2101e2691feSRoman Lebedev }
2111e2691feSRoman Lebedev
2121e2691feSRoman Lebedev if (!isa<SCEVUnknown>(Denominator))
2131e2691feSRoman Lebedev return cannotDivide(Numerator);
2141e2691feSRoman Lebedev
2151e2691feSRoman Lebedev // The Remainder is obtained by replacing Denominator by 0 in Numerator.
216*4635f605SFlorian Hahn ValueToSCEVMapTy RewriteMap;
217*4635f605SFlorian Hahn RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = Zero;
218*4635f605SFlorian Hahn Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
2191e2691feSRoman Lebedev
2201e2691feSRoman Lebedev if (Remainder->isZero()) {
2211e2691feSRoman Lebedev // The Quotient is obtained by replacing Denominator by 1 in Numerator.
222*4635f605SFlorian Hahn RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = One;
223*4635f605SFlorian Hahn Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap);
2241e2691feSRoman Lebedev return;
2251e2691feSRoman Lebedev }
2261e2691feSRoman Lebedev
2271e2691feSRoman Lebedev // Quotient is (Numerator - Remainder) divided by Denominator.
2281e2691feSRoman Lebedev const SCEV *Q, *R;
2291e2691feSRoman Lebedev const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
2301e2691feSRoman Lebedev // This SCEV does not seem to simplify: fail the division here.
2311e2691feSRoman Lebedev if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator))
2321e2691feSRoman Lebedev return cannotDivide(Numerator);
2331e2691feSRoman Lebedev divide(SE, Diff, Denominator, &Q, &R);
2341e2691feSRoman Lebedev if (R != Zero)
2351e2691feSRoman Lebedev return cannotDivide(Numerator);
2361e2691feSRoman Lebedev Quotient = Q;
2371e2691feSRoman Lebedev }
2381e2691feSRoman Lebedev
SCEVDivision(ScalarEvolution & S,const SCEV * Numerator,const SCEV * Denominator)2391e2691feSRoman Lebedev SCEVDivision::SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
2401e2691feSRoman Lebedev const SCEV *Denominator)
2411e2691feSRoman Lebedev : SE(S), Denominator(Denominator) {
2421e2691feSRoman Lebedev Zero = SE.getZero(Denominator->getType());
2431e2691feSRoman Lebedev One = SE.getOne(Denominator->getType());
2441e2691feSRoman Lebedev
2451e2691feSRoman Lebedev // We generally do not know how to divide Expr by Denominator. We initialize
2461e2691feSRoman Lebedev // the division to a "cannot divide" state to simplify the rest of the code.
2471e2691feSRoman Lebedev cannotDivide(Numerator);
2481e2691feSRoman Lebedev }
2491e2691feSRoman Lebedev
2501e2691feSRoman Lebedev // Convenience function for giving up on the division. We set the quotient to
2511e2691feSRoman Lebedev // be equal to zero and the remainder to be equal to the numerator.
cannotDivide(const SCEV * Numerator)2521e2691feSRoman Lebedev void SCEVDivision::cannotDivide(const SCEV *Numerator) {
2531e2691feSRoman Lebedev Quotient = Zero;
2541e2691feSRoman Lebedev Remainder = Numerator;
2551e2691feSRoman Lebedev }
256