1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PWMAFunction.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 
12 using namespace mlir;
13 
14 // Return the result of subtracting the two given vectors pointwise.
15 // The vectors must be of the same size.
16 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
17 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
18                                         ArrayRef<int64_t> vecB) {
19   assert(vecA.size() == vecB.size() &&
20          "Cannot subtract vectors of differing lengths!");
21   SmallVector<int64_t, 8> result;
22   result.reserve(vecA.size());
23   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
24     result.push_back(vecA[i] - vecB[i]);
25   return result;
26 }
27 
28 PresburgerSet PWMAFunction::getDomain() const {
29   PresburgerSet domain =
30       PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
31   for (const MultiAffineFunction &piece : pieces)
32     domain.unionPolyInPlace(piece.getDomain());
33   return domain;
34 }
35 
36 Optional<SmallVector<int64_t, 8>>
37 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
38   assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
39   assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
40 
41   if (!getDomain().containsPoint(point))
42     return {};
43 
44   // The point lies in the domain, so we need to compute the output value.
45   // The matrix `output` has an affine expression in the ith row, corresponding
46   // to the expression for the ith value in the output vector. The last column
47   // of the matrix contains the constant term. Let v be the input point with
48   // a 1 appended at the end. We can see that output * v gives the desired
49   // output vector.
50   SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
51   pointHomogenous.push_back(1);
52   SmallVector<int64_t, 8> result =
53       output.postMultiplyWithColumn(pointHomogenous);
54   assert(result.size() == getNumOutputs());
55   return result;
56 }
57 
58 Optional<SmallVector<int64_t, 8>>
59 PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
60   assert(point.size() == getNumInputs() &&
61          "Point has incorrect dimensionality!");
62   for (const MultiAffineFunction &piece : pieces)
63     if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
64       return output;
65   return {};
66 }
67 
68 void MultiAffineFunction::print(raw_ostream &os) const {
69   os << "Domain:";
70   IntegerPolyhedron::print(os);
71   os << "Output:\n";
72   output.print(os);
73   os << "\n";
74 }
75 
76 void MultiAffineFunction::dump() const { print(llvm::errs()); }
77 
78 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
79   return hasCompatibleDimensions(other) &&
80          getDomain().isEqual(other.getDomain()) &&
81          isEqualWhereDomainsOverlap(other);
82 }
83 
84 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
85                                        unsigned num) {
86   unsigned absolutePos = getIdKindOffset(kind) + pos;
87   output.insertColumns(absolutePos, num);
88   return IntegerPolyhedron::insertId(kind, pos, num);
89 }
90 
91 void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
92   output.swapColumns(posA, posB);
93   IntegerPolyhedron::swapId(posA, posB);
94 }
95 
96 void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
97   output.removeColumns(idStart, idLimit - idStart);
98   IntegerPolyhedron::removeIdRange(idStart, idLimit);
99 }
100 
101 void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
102                                                     unsigned posB) {
103   output.addToColumn(posB, posA, /*scale=*/1);
104   IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
105 }
106 
107 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
108     MultiAffineFunction other) const {
109   if (!hasCompatibleDimensions(other))
110     return false;
111 
112   // `commonFunc` has the same output as `this`.
113   MultiAffineFunction commonFunc = *this;
114   // After this merge, `commonFunc` and `other` have the same local ids; they
115   // are merged.
116   commonFunc.mergeLocalIds(other);
117   // After this, the domain of `commonFunc` will be the intersection of the
118   // domains of `this` and `other`.
119   commonFunc.IntegerPolyhedron::append(other);
120 
121   // `commonDomainMatching` contains the subset of the common domain
122   // where the outputs of `this` and `other` match.
123   //
124   // We want to add constraints equating the outputs of `this` and `other`.
125   // However, `this` may have difference local ids from `other`, whereas we
126   // need both to have the same locals. Accordingly, we use `commonFunc.output`
127   // in place of `this->output`, since `commonFunc` has the same output but also
128   // has its locals merged.
129   IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
130   for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
131     commonDomainMatching.addEquality(
132         subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
133 
134   // If the whole common domain is a subset of commonDomainMatching, then they
135   // are equal and the two functions match on the whole common domain.
136   return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
137 }
138 
139 /// Two PWMAFunctions are equal if they have the same dimensionalities,
140 /// the same domain, and take the same value at every point in the domain.
141 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
142   if (!hasCompatibleDimensions(other))
143     return false;
144 
145   if (!this->getDomain().isEqual(other.getDomain()))
146     return false;
147 
148   // Check if, whenever the domains of a piece of `this` and a piece of `other`
149   // overlap, they take the same output value. If `this` and `other` have the
150   // same domain (checked above), then this check passes iff the two functions
151   // have the same output at every point in the domain.
152   for (const MultiAffineFunction &aPiece : this->pieces)
153     for (const MultiAffineFunction &bPiece : other.pieces)
154       if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
155         return false;
156   return true;
157 }
158 
159 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
160   assert(hasCompatibleDimensions(piece) &&
161          "Piece to be added is not compatible with this PWMAFunction!");
162   assert(piece.isConsistent() && "Piece is internally inconsistent!");
163   assert(this->getDomain()
164              .intersect(PresburgerSet(piece.getDomain()))
165              .isIntegerEmpty() &&
166          "New piece's domain overlaps with that of existing pieces!");
167   pieces.push_back(piece);
168 }
169 
170 void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
171                             const Matrix &output) {
172   addPiece(MultiAffineFunction(domain, output));
173 }
174 
175 void PWMAFunction::print(raw_ostream &os) const {
176   os << pieces.size() << " pieces:\n";
177   for (const MultiAffineFunction &piece : pieces)
178     piece.print(os);
179 }
180 
181 /// The hasCompatibleDimensions functions don't check the number of local ids;
182 /// functions are still compatible if they have differing number of locals.
183 bool MultiAffineFunction::hasCompatibleDimensions(
184     const MultiAffineFunction &f) const {
185   return getNumDimIds() == f.getNumDimIds() &&
186          getNumSymbolIds() == f.getNumSymbolIds() &&
187          getNumOutputs() == f.getNumOutputs();
188 }
189 bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
190   return getNumDimIds() == f.getNumDimIds() &&
191          getNumSymbolIds() == f.getNumSymbolIds() &&
192          getNumOutputs() == f.getNumOutputs();
193 }
194 bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
195   return getNumDimIds() == f.getNumDimIds() &&
196          getNumSymbolIds() == f.getNumSymbolIds() &&
197          getNumOutputs() == f.getNumOutputs();
198 }
199