1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PWMAFunction.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 
12 using namespace mlir;
13 using namespace presburger;
14 
15 // Return the result of subtracting the two given vectors pointwise.
16 // The vectors must be of the same size.
17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
19                                         ArrayRef<int64_t> vecB) {
20   assert(vecA.size() == vecB.size() &&
21          "Cannot subtract vectors of differing lengths!");
22   SmallVector<int64_t, 8> result;
23   result.reserve(vecA.size());
24   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
25     result.push_back(vecA[i] - vecB[i]);
26   return result;
27 }
28 
29 PresburgerSet PWMAFunction::getDomain() const {
30   PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
31   for (const MultiAffineFunction &piece : pieces)
32     domain.unionInPlace(piece.getDomain());
33   return domain;
34 }
35 
36 Optional<SmallVector<int64_t, 8>>
37 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
38   assert(point.size() == domainSet.getNumDimAndSymbolIds() &&
39          "Point has incorrect dimensionality!");
40 
41   Optional<SmallVector<int64_t, 8>> maybeLocalValues =
42       getDomain().containsPointNoLocal(point);
43   if (!maybeLocalValues)
44     return {};
45 
46   // The point lies in the domain, so we need to compute the output value.
47   SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
48   // The given point didn't include the values of locals which the output is a
49   // function of; we have computed one possible set of values and use them
50   // here. The function is not allowed to have local ids that take more than
51   // one possible value.
52   pointHomogenous.append(*maybeLocalValues);
53   // The matrix `output` has an affine expression in the ith row, corresponding
54   // to the expression for the ith value in the output vector. The last column
55   // of the matrix contains the constant term. Let v be the input point with
56   // a 1 appended at the end. We can see that output * v gives the desired
57   // output vector.
58   pointHomogenous.emplace_back(1);
59   SmallVector<int64_t, 8> result =
60       output.postMultiplyWithColumn(pointHomogenous);
61   assert(result.size() == getNumOutputs());
62   return result;
63 }
64 
65 Optional<SmallVector<int64_t, 8>>
66 PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
67   assert(point.size() == getNumInputs() &&
68          "Point has incorrect dimensionality!");
69   for (const MultiAffineFunction &piece : pieces)
70     if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
71       return output;
72   return {};
73 }
74 
75 void MultiAffineFunction::print(raw_ostream &os) const {
76   os << "Domain:";
77   domainSet.print(os);
78   os << "Output:\n";
79   output.print(os);
80   os << "\n";
81 }
82 
83 void MultiAffineFunction::dump() const { print(llvm::errs()); }
84 
85 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
86   return getDomainSpace().isCompatible(other.getDomainSpace()) &&
87          getDomain().isEqual(other.getDomain()) &&
88          isEqualWhereDomainsOverlap(other);
89 }
90 
91 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
92                                        unsigned num) {
93   assert(kind != IdKind::Domain && "Domain has to be zero in a set");
94   unsigned absolutePos = domainSet.getIdKindOffset(kind) + pos;
95   output.insertColumns(absolutePos, num);
96   return domainSet.insertId(kind, pos, num);
97 }
98 
99 void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
100                                         unsigned idLimit) {
101   output.removeColumns(idStart + domainSet.getIdKindOffset(kind),
102                        idLimit - idStart);
103   domainSet.removeIdRange(kind, idStart, idLimit);
104 }
105 
106 void MultiAffineFunction::truncateOutput(unsigned count) {
107   assert(count <= output.getNumRows());
108   output.resizeVertically(count);
109 }
110 
111 void PWMAFunction::truncateOutput(unsigned count) {
112   assert(count <= numOutputs);
113   for (MultiAffineFunction &piece : pieces)
114     piece.truncateOutput(count);
115   numOutputs = count;
116 }
117 
/// Merge the local ids of `this` and `other` via presburger::mergeLocalIds on
/// their domain sets. The output matrices of both functions are updated
/// alongside, so that every local column added to or removed from a domain is
/// mirrored in the corresponding `output`. Both `this` and `other` are
/// modified.
void MultiAffineFunction::mergeLocalIds(MultiAffineFunction &other) {
  // Merge output local ids of both functions without using division
  // information i.e. append local ids of `other` to `this` and insert
  // local ids of `this` to `other` at the start of its local ids.
  // NOTE(review): this assumes presburger::mergeLocalIds lays out the domain
  // locals the same way (this's locals first, then other's); confirm against
  // its implementation if either side changes.
  output.insertColumns(domainSet.getIdKindEnd(IdKind::Local),
                       other.domainSet.getNumLocalIds());
  other.output.insertColumns(other.domainSet.getIdKindOffset(IdKind::Local),
                             domainSet.getNumLocalIds());

  // Callback handed to presburger::mergeLocalIds below. It is presumably
  // invoked for pairs of locals (i, j) that it has determined to be
  // identical. It folds local j into local i in both domains and both output
  // matrices.
  auto merge = [this, &other](unsigned i, unsigned j) -> bool {
    // Merge local at position j into local at position i in function domain.
    domainSet.eliminateRedundantLocalId(i, j);
    other.domainSet.eliminateRedundantLocalId(i, j);

    // The same offset is used for `other.output`, which assumes both domains
    // have the same number of non-local ids.
    unsigned localOffset = domainSet.getIdKindOffset(IdKind::Local);

    // Merge local at position j into local at position i in output domain.
    output.addToColumn(localOffset + j, localOffset + i, 1);
    output.removeColumn(localOffset + j);
    other.output.addToColumn(localOffset + j, localOffset + i, 1);
    other.output.removeColumn(localOffset + j);

    return true;
  };

  presburger::mergeLocalIds(domainSet, other.domainSet, merge);
}
145 
146 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
147     MultiAffineFunction other) const {
148   if (!getDomainSpace().isCompatible(other.getDomainSpace()))
149     return false;
150 
151   // `commonFunc` has the same output as `this`.
152   MultiAffineFunction commonFunc = *this;
153   // After this merge, `commonFunc` and `other` have the same local ids; they
154   // are merged.
155   commonFunc.mergeLocalIds(other);
156   // After this, the domain of `commonFunc` will be the intersection of the
157   // domains of `this` and `other`.
158   commonFunc.domainSet.append(other.domainSet);
159 
160   // `commonDomainMatching` contains the subset of the common domain
161   // where the outputs of `this` and `other` match.
162   //
163   // We want to add constraints equating the outputs of `this` and `other`.
164   // However, `this` may have difference local ids from `other`, whereas we
165   // need both to have the same locals. Accordingly, we use `commonFunc.output`
166   // in place of `this->output`, since `commonFunc` has the same output but also
167   // has its locals merged.
168   IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
169   for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
170     commonDomainMatching.addEquality(
171         subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
172 
173   // If the whole common domain is a subset of commonDomainMatching, then they
174   // are equal and the two functions match on the whole common domain.
175   return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
176 }
177 
178 /// Two PWMAFunctions are equal if they have the same dimensionalities,
179 /// the same domain, and take the same value at every point in the domain.
180 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
181   if (!space.isCompatible(other.space))
182     return false;
183 
184   if (!this->getDomain().isEqual(other.getDomain()))
185     return false;
186 
187   // Check if, whenever the domains of a piece of `this` and a piece of `other`
188   // overlap, they take the same output value. If `this` and `other` have the
189   // same domain (checked above), then this check passes iff the two functions
190   // have the same output at every point in the domain.
191   for (const MultiAffineFunction &aPiece : this->pieces)
192     for (const MultiAffineFunction &bPiece : other.pieces)
193       if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
194         return false;
195   return true;
196 }
197 
198 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
199   assert(space.isCompatible(piece.getDomainSpace()) &&
200          "Piece to be added is not compatible with this PWMAFunction!");
201   assert(piece.isConsistent() && "Piece is internally inconsistent!");
202   assert(this->getDomain()
203              .intersect(PresburgerSet(piece.getDomain()))
204              .isIntegerEmpty() &&
205          "New piece's domain overlaps with that of existing pieces!");
206   pieces.push_back(piece);
207 }
208 
209 void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
210                             const Matrix &output) {
211   addPiece(MultiAffineFunction(domain, output));
212 }
213 
214 void PWMAFunction::print(raw_ostream &os) const {
215   os << pieces.size() << " pieces:\n";
216   for (const MultiAffineFunction &piece : pieces)
217     piece.print(os);
218 }
219 
220 void PWMAFunction::dump() const { print(llvm::errs()); }
221