18506c8c1SGroverkss //===- LinearTransform.cpp - MLIR LinearTransform Class -------------------===//
28506c8c1SGroverkss //
38506c8c1SGroverkss // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48506c8c1SGroverkss // See https://llvm.org/LICENSE.txt for license information.
58506c8c1SGroverkss // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68506c8c1SGroverkss //
78506c8c1SGroverkss //===----------------------------------------------------------------------===//
88506c8c1SGroverkss 
98506c8c1SGroverkss #include "mlir/Analysis/Presburger/LinearTransform.h"
10bb901355SGroverkss #include "mlir/Analysis/Presburger/IntegerRelation.h"
118506c8c1SGroverkss 
120c1f6865SGroverkss using namespace mlir;
130c1f6865SGroverkss using namespace presburger;
148506c8c1SGroverkss 
LinearTransform(Matrix && oMatrix)158506c8c1SGroverkss LinearTransform::LinearTransform(Matrix &&oMatrix) : matrix(oMatrix) {}
LinearTransform(const Matrix & oMatrix)168506c8c1SGroverkss LinearTransform::LinearTransform(const Matrix &oMatrix) : matrix(oMatrix) {}
178506c8c1SGroverkss 
188506c8c1SGroverkss // Set M(row, targetCol) to its remainder on division by M(row, sourceCol)
198506c8c1SGroverkss // by subtracting from column targetCol an appropriate integer multiple of
208506c8c1SGroverkss // sourceCol. This brings M(row, targetCol) to the range [0, M(row, sourceCol)).
218506c8c1SGroverkss // Apply the same column operation to otherMatrix, with the same integer
228506c8c1SGroverkss // multiple.
modEntryColumnOperation(Matrix & m,unsigned row,unsigned sourceCol,unsigned targetCol,Matrix & otherMatrix)238506c8c1SGroverkss static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
248506c8c1SGroverkss                                     unsigned targetCol, Matrix &otherMatrix) {
258506c8c1SGroverkss   assert(m(row, sourceCol) != 0 && "Cannot divide by zero!");
268506c8c1SGroverkss   assert((m(row, sourceCol) > 0 && m(row, targetCol) > 0) &&
278506c8c1SGroverkss          "Operands must be positive!");
288506c8c1SGroverkss   int64_t ratio = m(row, targetCol) / m(row, sourceCol);
298506c8c1SGroverkss   m.addToColumn(sourceCol, targetCol, -ratio);
308506c8c1SGroverkss   otherMatrix.addToColumn(sourceCol, targetCol, -ratio);
318506c8c1SGroverkss }
328506c8c1SGroverkss 
338506c8c1SGroverkss std::pair<unsigned, LinearTransform>
makeTransformToColumnEchelon(Matrix m)348506c8c1SGroverkss LinearTransform::makeTransformToColumnEchelon(Matrix m) {
358506c8c1SGroverkss   // We start with an identity result matrix and perform operations on m
368506c8c1SGroverkss   // until m is in column echelon form. We apply the same sequence of operations
378506c8c1SGroverkss   // on resultMatrix to obtain a transform that takes m to column echelon
388506c8c1SGroverkss   // form.
398506c8c1SGroverkss   Matrix resultMatrix = Matrix::identity(m.getNumColumns());
408506c8c1SGroverkss 
418506c8c1SGroverkss   unsigned echelonCol = 0;
428506c8c1SGroverkss   // Invariant: in all rows above row, all columns from echelonCol onwards
438506c8c1SGroverkss   // are all zero elements. In an iteration, if the curent row has any non-zero
448506c8c1SGroverkss   // elements echelonCol onwards, we bring one to echelonCol and use it to
458506c8c1SGroverkss   // make all elements echelonCol + 1 onwards zero.
468506c8c1SGroverkss   for (unsigned row = 0; row < m.getNumRows(); ++row) {
478506c8c1SGroverkss     // Search row for a non-empty entry, starting at echelonCol.
488506c8c1SGroverkss     unsigned nonZeroCol = echelonCol;
498506c8c1SGroverkss     for (unsigned e = m.getNumColumns(); nonZeroCol < e; ++nonZeroCol) {
508506c8c1SGroverkss       if (m(row, nonZeroCol) == 0)
518506c8c1SGroverkss         continue;
528506c8c1SGroverkss       break;
538506c8c1SGroverkss     }
548506c8c1SGroverkss 
558506c8c1SGroverkss     // Continue to the next row with the same echelonCol if this row is all
568506c8c1SGroverkss     // zeros from echelonCol onwards.
578506c8c1SGroverkss     if (nonZeroCol == m.getNumColumns())
588506c8c1SGroverkss       continue;
598506c8c1SGroverkss 
608506c8c1SGroverkss     // Bring the non-zero column to echelonCol. This doesn't affect rows
618506c8c1SGroverkss     // above since they are all zero at these columns.
628506c8c1SGroverkss     if (nonZeroCol != echelonCol) {
638506c8c1SGroverkss       m.swapColumns(nonZeroCol, echelonCol);
648506c8c1SGroverkss       resultMatrix.swapColumns(nonZeroCol, echelonCol);
658506c8c1SGroverkss     }
668506c8c1SGroverkss 
678506c8c1SGroverkss     // Make m(row, echelonCol) non-negative.
688506c8c1SGroverkss     if (m(row, echelonCol) < 0) {
698506c8c1SGroverkss       m.negateColumn(echelonCol);
708506c8c1SGroverkss       resultMatrix.negateColumn(echelonCol);
718506c8c1SGroverkss     }
728506c8c1SGroverkss 
738506c8c1SGroverkss     // Make all the entries in row after echelonCol zero.
748506c8c1SGroverkss     for (unsigned i = echelonCol + 1, e = m.getNumColumns(); i < e; ++i) {
758506c8c1SGroverkss       // We make m(row, i) non-negative, and then apply the Euclidean GCD
768506c8c1SGroverkss       // algorithm to (row, i) and (row, echelonCol). At the end, one of them
778506c8c1SGroverkss       // has value equal to the gcd of the two entries, and the other is zero.
788506c8c1SGroverkss 
798506c8c1SGroverkss       if (m(row, i) < 0) {
808506c8c1SGroverkss         m.negateColumn(i);
818506c8c1SGroverkss         resultMatrix.negateColumn(i);
828506c8c1SGroverkss       }
838506c8c1SGroverkss 
848506c8c1SGroverkss       unsigned targetCol = i, sourceCol = echelonCol;
858506c8c1SGroverkss       // At every step, we set m(row, targetCol) %= m(row, sourceCol), and
868506c8c1SGroverkss       // swap the indices sourceCol and targetCol. (not the columns themselves)
878506c8c1SGroverkss       // This modulo is implemented as a subtraction
888506c8c1SGroverkss       // m(row, targetCol) -= quotient * m(row, sourceCol),
898506c8c1SGroverkss       // where quotient = floor(m(row, targetCol) / m(row, sourceCol)),
908506c8c1SGroverkss       // which brings m(row, targetCol) to the range [0, m(row, sourceCol)).
918506c8c1SGroverkss       //
928506c8c1SGroverkss       // We are only allowed column operations; we perform the above
938506c8c1SGroverkss       // for every row, i.e., the above subtraction is done as a column
948506c8c1SGroverkss       // operation. This does not affect any rows above us since they are
958506c8c1SGroverkss       // guaranteed to be zero at these columns.
968506c8c1SGroverkss       while (m(row, targetCol) != 0 && m(row, sourceCol) != 0) {
978506c8c1SGroverkss         modEntryColumnOperation(m, row, sourceCol, targetCol, resultMatrix);
988506c8c1SGroverkss         std::swap(targetCol, sourceCol);
998506c8c1SGroverkss       }
1008506c8c1SGroverkss 
1018506c8c1SGroverkss       // One of (row, echelonCol) and (row, i) is zero and the other is the gcd.
1028506c8c1SGroverkss       // Make it so that (row, echelonCol) holds the non-zero value.
1038506c8c1SGroverkss       if (m(row, echelonCol) == 0) {
1048506c8c1SGroverkss         m.swapColumns(i, echelonCol);
1058506c8c1SGroverkss         resultMatrix.swapColumns(i, echelonCol);
1068506c8c1SGroverkss       }
1078506c8c1SGroverkss     }
1088506c8c1SGroverkss 
1098506c8c1SGroverkss     ++echelonCol;
1108506c8c1SGroverkss   }
1118506c8c1SGroverkss 
1128506c8c1SGroverkss   return {echelonCol, LinearTransform(std::move(resultMatrix))};
1138506c8c1SGroverkss }
1148506c8c1SGroverkss 
// Apply this transform to every constraint of `rel` and return the resulting
// relation, which is built in the same space as `rel`.
//
// Each constraint row is split into its variable coefficients (all entries but
// the last) and its trailing constant term. The coefficients are passed through
// preMultiplyWithRow, and the constant term is appended back unchanged.
IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const {
  IntegerRelation result(rel.getSpace());

  // Shared by equalities and inequalities: both use the same row layout.
  auto transformRow = [this](ArrayRef<int64_t> row) {
    SmallVector<int64_t, 8> transformed = preMultiplyWithRow(row.drop_back());
    transformed.push_back(row.back());
    return transformed;
  };

  for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i)
    result.addEquality(transformRow(rel.getEquality(i)));

  for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i)
    result.addInequality(transformRow(rel.getInequality(i)));

  return result;
}
140