1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PWMAFunction.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 
12 using namespace mlir;
13 using namespace presburger;
14 
15 // Return the result of subtracting the two given vectors pointwise.
16 // The vectors must be of the same size.
17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
19                                         ArrayRef<int64_t> vecB) {
20   assert(vecA.size() == vecB.size() &&
21          "Cannot subtract vectors of differing lengths!");
22   SmallVector<int64_t, 8> result;
23   result.reserve(vecA.size());
24   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
25     result.push_back(vecA[i] - vecB[i]);
26   return result;
27 }
28 
29 PresburgerSet PWMAFunction::getDomain() const {
30   PresburgerSet domain =
31       PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
32   for (const MultiAffineFunction &piece : pieces)
33     domain.unionPolyInPlace(piece.getDomain());
34   return domain;
35 }
36 
37 Optional<SmallVector<int64_t, 8>>
38 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
39   assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
40   assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
41 
42   if (!getDomain().containsPoint(point))
43     return {};
44 
45   // The point lies in the domain, so we need to compute the output value.
46   // The matrix `output` has an affine expression in the ith row, corresponding
47   // to the expression for the ith value in the output vector. The last column
48   // of the matrix contains the constant term. Let v be the input point with
49   // a 1 appended at the end. We can see that output * v gives the desired
50   // output vector.
51   SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
52   pointHomogenous.push_back(1);
53   SmallVector<int64_t, 8> result =
54       output.postMultiplyWithColumn(pointHomogenous);
55   assert(result.size() == getNumOutputs());
56   return result;
57 }
58 
59 Optional<SmallVector<int64_t, 8>>
60 PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
61   assert(point.size() == getNumInputs() &&
62          "Point has incorrect dimensionality!");
63   for (const MultiAffineFunction &piece : pieces)
64     if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
65       return output;
66   return {};
67 }
68 
69 void MultiAffineFunction::print(raw_ostream &os) const {
70   os << "Domain:";
71   IntegerPolyhedron::print(os);
72   os << "Output:\n";
73   output.print(os);
74   os << "\n";
75 }
76 
77 void MultiAffineFunction::dump() const { print(llvm::errs()); }
78 
79 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
80   return PresburgerSpace::isEqual(other) &&
81          getDomain().isEqual(other.getDomain()) &&
82          isEqualWhereDomainsOverlap(other);
83 }
84 
85 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
86                                        unsigned num) {
87   unsigned absolutePos = getIdKindOffset(kind) + pos;
88   output.insertColumns(absolutePos, num);
89   return IntegerPolyhedron::insertId(kind, pos, num);
90 }
91 
92 void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
93   output.swapColumns(posA, posB);
94   IntegerPolyhedron::swapId(posA, posB);
95 }
96 
97 void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
98   output.removeColumns(idStart, idLimit - idStart);
99   IntegerPolyhedron::removeIdRange(idStart, idLimit);
100 }
101 
102 void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
103                                                     unsigned posB) {
104   output.addToColumn(posB, posA, /*scale=*/1);
105   IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
106 }
107 
108 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
109     MultiAffineFunction other) const {
110   if (!PresburgerSpace::isEqual(other))
111     return false;
112 
113   // `commonFunc` has the same output as `this`.
114   MultiAffineFunction commonFunc = *this;
115   // After this merge, `commonFunc` and `other` have the same local ids; they
116   // are merged.
117   commonFunc.mergeLocalIds(other);
118   // After this, the domain of `commonFunc` will be the intersection of the
119   // domains of `this` and `other`.
120   commonFunc.IntegerPolyhedron::append(other);
121 
122   // `commonDomainMatching` contains the subset of the common domain
123   // where the outputs of `this` and `other` match.
124   //
125   // We want to add constraints equating the outputs of `this` and `other`.
126   // However, `this` may have difference local ids from `other`, whereas we
127   // need both to have the same locals. Accordingly, we use `commonFunc.output`
128   // in place of `this->output`, since `commonFunc` has the same output but also
129   // has its locals merged.
130   IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
131   for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
132     commonDomainMatching.addEquality(
133         subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
134 
135   // If the whole common domain is a subset of commonDomainMatching, then they
136   // are equal and the two functions match on the whole common domain.
137   return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
138 }
139 
140 /// Two PWMAFunctions are equal if they have the same dimensionalities,
141 /// the same domain, and take the same value at every point in the domain.
142 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
143   if (!PresburgerSpace::isEqual(other))
144     return false;
145 
146   if (!this->getDomain().isEqual(other.getDomain()))
147     return false;
148 
149   // Check if, whenever the domains of a piece of `this` and a piece of `other`
150   // overlap, they take the same output value. If `this` and `other` have the
151   // same domain (checked above), then this check passes iff the two functions
152   // have the same output at every point in the domain.
153   for (const MultiAffineFunction &aPiece : this->pieces)
154     for (const MultiAffineFunction &bPiece : other.pieces)
155       if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
156         return false;
157   return true;
158 }
159 
160 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
161   assert(piece.isSpaceEqual(*this) &&
162          "Piece to be added is not compatible with this PWMAFunction!");
163   assert(piece.isConsistent() && "Piece is internally inconsistent!");
164   assert(this->getDomain()
165              .intersect(PresburgerSet(piece.getDomain()))
166              .isIntegerEmpty() &&
167          "New piece's domain overlaps with that of existing pieces!");
168   pieces.push_back(piece);
169 }
170 
171 void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
172                             const Matrix &output) {
173   addPiece(MultiAffineFunction(domain, output));
174 }
175 
176 void PWMAFunction::print(raw_ostream &os) const {
177   os << pieces.size() << " pieces:\n";
178   for (const MultiAffineFunction &piece : pieces)
179     piece.print(os);
180 }
181