1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PWMAFunction.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11 
12 using namespace mlir;
13 using namespace presburger;
14 
15 // Return the result of subtracting the two given vectors pointwise.
16 // The vectors must be of the same size.
17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
19                                         ArrayRef<int64_t> vecB) {
20   assert(vecA.size() == vecB.size() &&
21          "Cannot subtract vectors of differing lengths!");
22   SmallVector<int64_t, 8> result;
23   result.reserve(vecA.size());
24   for (unsigned i = 0, e = vecA.size(); i < e; ++i)
25     result.push_back(vecA[i] - vecB[i]);
26   return result;
27 }
28 
29 PresburgerSet PWMAFunction::getDomain() const {
30   PresburgerSet domain =
31       PresburgerSet::getEmpty(getNumDimIds(), getNumSymbolIds());
32   for (const MultiAffineFunction &piece : pieces)
33     domain.unionInPlace(piece.getDomain());
34   return domain;
35 }
36 
37 Optional<SmallVector<int64_t, 8>>
38 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
39   assert(point.size() == getNumDimAndSymbolIds() &&
40          "Point has incorrect dimensionality!");
41 
42   Optional<SmallVector<int64_t, 8>> maybeLocalValues =
43       getDomain().containsPointNoLocal(point);
44   if (!maybeLocalValues)
45     return {};
46 
47   // The point lies in the domain, so we need to compute the output value.
48   SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
49   // The given point didn't include the values of locals which the output is a
50   // function of; we have computed one possible set of values and use them
51   // here. The function is not allowed to have local ids that take more than
52   // one possible value.
53   pointHomogenous.append(*maybeLocalValues);
54   // The matrix `output` has an affine expression in the ith row, corresponding
55   // to the expression for the ith value in the output vector. The last column
56   // of the matrix contains the constant term. Let v be the input point with
57   // a 1 appended at the end. We can see that output * v gives the desired
58   // output vector.
59   pointHomogenous.push_back(1);
60   SmallVector<int64_t, 8> result =
61       output.postMultiplyWithColumn(pointHomogenous);
62   assert(result.size() == getNumOutputs());
63   return result;
64 }
65 
66 Optional<SmallVector<int64_t, 8>>
67 PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
68   assert(point.size() == getNumInputs() &&
69          "Point has incorrect dimensionality!");
70   for (const MultiAffineFunction &piece : pieces)
71     if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
72       return output;
73   return {};
74 }
75 
76 void MultiAffineFunction::print(raw_ostream &os) const {
77   os << "Domain:";
78   IntegerPolyhedron::print(os);
79   os << "Output:\n";
80   output.print(os);
81   os << "\n";
82 }
83 
84 void MultiAffineFunction::dump() const { print(llvm::errs()); }
85 
86 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
87   return isSpaceCompatible(other) && getDomain().isEqual(other.getDomain()) &&
88          isEqualWhereDomainsOverlap(other);
89 }
90 
91 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
92                                        unsigned num) {
93   assert((kind != IdKind::Domain || num == 0) &&
94          "Domain has to be zero in a set");
95   unsigned absolutePos = getIdKindOffset(kind) + pos;
96   output.insertColumns(absolutePos, num);
97   return IntegerPolyhedron::insertId(kind, pos, num);
98 }
99 
100 void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
101   output.swapColumns(posA, posB);
102   IntegerPolyhedron::swapId(posA, posB);
103 }
104 
105 void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
106                                         unsigned idLimit) {
107   output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart);
108   IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
109 }
110 
111 void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
112                                                     unsigned posB) {
113   output.addToColumn(posB, posA, /*scale=*/1);
114   IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
115 }
116 
/// Returns true iff `this` and `other` have compatible spaces and produce the
/// same output at every point in the intersection of their domains. Points
/// outside the intersection are not considered. `other` is taken by value so
/// that the local-id merge below can modify it without affecting the caller.
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
    MultiAffineFunction other) const {
  if (!isSpaceCompatible(other))
    return false;

  // `commonFunc` has the same output as `this`.
  MultiAffineFunction commonFunc = *this;
  // After this merge, `commonFunc` and `other` have the same local ids, so
  // their output matrices have matching columns.
  commonFunc.mergeLocalIds(other);
  // After this, the domain of `commonFunc` will be the intersection of the
  // domains of `this` and `other`.
  commonFunc.IntegerPolyhedron::append(other);

  // `commonDomainMatching` contains the subset of the common domain
  // where the outputs of `this` and `other` match.
  //
  // We want to add constraints equating the outputs of `this` and `other`.
  // However, `this` may have different local ids from `other`, whereas we
  // need both to have the same locals. Accordingly, we use `commonFunc.output`
  // in place of `this->output`, since `commonFunc` has the same output but also
  // has its locals merged.
  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
    commonDomainMatching.addEquality(
        subtract(commonFunc.output.getRow(row), other.output.getRow(row)));

  // If the whole common domain is a subset of commonDomainMatching, then they
  // are equal and the two functions match on the whole common domain.
  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
}
148 
149 /// Two PWMAFunctions are equal if they have the same dimensionalities,
150 /// the same domain, and take the same value at every point in the domain.
151 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
152   if (!isSpaceCompatible(other))
153     return false;
154 
155   if (!this->getDomain().isEqual(other.getDomain()))
156     return false;
157 
158   // Check if, whenever the domains of a piece of `this` and a piece of `other`
159   // overlap, they take the same output value. If `this` and `other` have the
160   // same domain (checked above), then this check passes iff the two functions
161   // have the same output at every point in the domain.
162   for (const MultiAffineFunction &aPiece : this->pieces)
163     for (const MultiAffineFunction &bPiece : other.pieces)
164       if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
165         return false;
166   return true;
167 }
168 
169 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
170   assert(piece.isSpaceCompatible(*this) &&
171          "Piece to be added is not compatible with this PWMAFunction!");
172   assert(piece.isConsistent() && "Piece is internally inconsistent!");
173   assert(this->getDomain()
174              .intersect(PresburgerSet(piece.getDomain()))
175              .isIntegerEmpty() &&
176          "New piece's domain overlaps with that of existing pieces!");
177   pieces.push_back(piece);
178 }
179 
180 void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
181                             const Matrix &output) {
182   addPiece(MultiAffineFunction(domain, output));
183 }
184 
185 void PWMAFunction::print(raw_ostream &os) const {
186   os << pieces.size() << " pieces:\n";
187   for (const MultiAffineFunction &piece : pieces)
188     piece.print(os);
189 }
190 
191 void PWMAFunction::dump() const { print(llvm::errs()); }
192