1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/PWMAFunction.h" 10 #include "mlir/Analysis/Presburger/Simplex.h" 11 12 using namespace mlir; 13 14 // Return the result of subtracting the two given vectors pointwise. 15 // The vectors must be of the same size. 16 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 17 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA, 18 ArrayRef<int64_t> vecB) { 19 assert(vecA.size() == vecB.size() && 20 "Cannot subtract vectors of differing lengths!"); 21 SmallVector<int64_t, 8> result; 22 result.reserve(vecA.size()); 23 for (unsigned i = 0, e = vecA.size(); i < e; ++i) 24 result.push_back(vecA[i] - vecB[i]); 25 return result; 26 } 27 28 PresburgerSet PWMAFunction::getDomain() const { 29 PresburgerSet domain = 30 PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds()); 31 for (const MultiAffineFunction &piece : pieces) 32 domain.unionPolyInPlace(piece.getDomain()); 33 return domain; 34 } 35 36 Optional<SmallVector<int64_t, 8>> 37 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const { 38 assert(getNumLocalIds() == 0 && "Local ids are not yet supported!"); 39 assert(point.size() == getNumIds() && "Point has incorrect dimensionality!"); 40 41 if (!getDomain().containsPoint(point)) 42 return {}; 43 44 // The point lies in the domain, so we need to compute the output value. 45 // The matrix `output` has an affine expression in the ith row, corresponding 46 // to the expression for the ith value in the output vector. The last column 47 // of the matrix contains the constant term. Let v be the input point with 48 // a 1 appended at the end. We can see that output * v gives the desired 49 // output vector. 50 SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)}; 51 pointHomogenous.push_back(1); 52 SmallVector<int64_t, 8> result = 53 output.postMultiplyWithColumn(pointHomogenous); 54 assert(result.size() == getNumOutputs()); 55 return result; 56 } 57 58 Optional<SmallVector<int64_t, 8>> 59 PWMAFunction::valueAt(ArrayRef<int64_t> point) const { 60 assert(point.size() == getNumInputs() && 61 "Point has incorrect dimensionality!"); 62 for (const MultiAffineFunction &piece : pieces) 63 if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point)) 64 return output; 65 return {}; 66 } 67 68 void MultiAffineFunction::print(raw_ostream &os) const { 69 os << "Domain:"; 70 IntegerPolyhedron::print(os); 71 os << "Output:\n"; 72 output.print(os); 73 os << "\n"; 74 } 75 76 void MultiAffineFunction::dump() const { print(llvm::errs()); } 77 78 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { 79 return hasCompatibleDimensions(other) && 80 getDomain().isEqual(other.getDomain()) && 81 isEqualWhereDomainsOverlap(other); 82 } 83 84 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos, 85 unsigned num) { 86 unsigned absolutePos = getIdKindOffset(kind) + pos; 87 output.insertColumns(absolutePos, num); 88 return IntegerPolyhedron::insertId(kind, pos, num); 89 } 90 91 void MultiAffineFunction::swapId(unsigned posA, unsigned posB) { 92 output.swapColumns(posA, posB); 93 IntegerPolyhedron::swapId(posA, posB); 94 } 95 96 void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) { 97 output.removeColumns(idStart, idLimit - idStart); 98 IntegerPolyhedron::removeIdRange(idStart, idLimit); 99 } 100 101 void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, 102 unsigned posB) { 103 output.addToColumn(posB, posA, /*scale=*/1); 104 IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); 105 } 106 107 bool MultiAffineFunction::isEqualWhereDomainsOverlap( 108 MultiAffineFunction other) const { 109 if (!hasCompatibleDimensions(other)) 110 return false; 111 112 // `commonFunc` has the same output as `this`. 113 MultiAffineFunction commonFunc = *this; 114 // After this merge, `commonFunc` and `other` have the same local ids; they 115 // are merged. 116 commonFunc.mergeLocalIds(other); 117 // After this, the domain of `commonFunc` will be the intersection of the 118 // domains of `this` and `other`. 119 commonFunc.IntegerPolyhedron::append(other); 120 121 // `commonDomainMatching` contains the subset of the common domain 122 // where the outputs of `this` and `other` match. 123 // 124 // We want to add constraints equating the outputs of `this` and `other`. 125 // However, `this` may have difference local ids from `other`, whereas we 126 // need both to have the same locals. Accordingly, we use `commonFunc.output` 127 // in place of `this->output`, since `commonFunc` has the same output but also 128 // has its locals merged. 129 IntegerPolyhedron commonDomainMatching = commonFunc.getDomain(); 130 for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) 131 commonDomainMatching.addEquality( 132 subtract(commonFunc.output.getRow(row), other.output.getRow(row))); 133 134 // If the whole common domain is a subset of commonDomainMatching, then they 135 // are equal and the two functions match on the whole common domain. 136 return commonFunc.getDomain().isSubsetOf(commonDomainMatching); 137 } 138 139 /// Two PWMAFunctions are equal if they have the same dimensionalities, 140 /// the same domain, and take the same value at every point in the domain. 141 bool PWMAFunction::isEqual(const PWMAFunction &other) const { 142 if (!hasCompatibleDimensions(other)) 143 return false; 144 145 if (!this->getDomain().isEqual(other.getDomain())) 146 return false; 147 148 // Check if, whenever the domains of a piece of `this` and a piece of `other` 149 // overlap, they take the same output value. If `this` and `other` have the 150 // same domain (checked above), then this check passes iff the two functions 151 // have the same output at every point in the domain. 152 for (const MultiAffineFunction &aPiece : this->pieces) 153 for (const MultiAffineFunction &bPiece : other.pieces) 154 if (!aPiece.isEqualWhereDomainsOverlap(bPiece)) 155 return false; 156 return true; 157 } 158 159 void PWMAFunction::addPiece(const MultiAffineFunction &piece) { 160 assert(hasCompatibleDimensions(piece) && 161 "Piece to be added is not compatible with this PWMAFunction!"); 162 assert(piece.isConsistent() && "Piece is internally inconsistent!"); 163 assert(this->getDomain() 164 .intersect(PresburgerSet(piece.getDomain())) 165 .isIntegerEmpty() && 166 "New piece's domain overlaps with that of existing pieces!"); 167 pieces.push_back(piece); 168 } 169 170 void PWMAFunction::addPiece(const IntegerPolyhedron &domain, 171 const Matrix &output) { 172 addPiece(MultiAffineFunction(domain, output)); 173 } 174 175 void PWMAFunction::print(raw_ostream &os) const { 176 os << pieces.size() << " pieces:\n"; 177 for (const MultiAffineFunction &piece : pieces) 178 piece.print(os); 179 } 180 181 /// The hasCompatibleDimensions functions don't check the number of local ids; 182 /// functions are still compatible if they have differing number of locals. 183 bool MultiAffineFunction::hasCompatibleDimensions( 184 const MultiAffineFunction &f) const { 185 return getNumDimIds() == f.getNumDimIds() && 186 getNumSymbolIds() == f.getNumSymbolIds() && 187 getNumOutputs() == f.getNumOutputs(); 188 } 189 bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const { 190 return getNumDimIds() == f.getNumDimIds() && 191 getNumSymbolIds() == f.getNumSymbolIds() && 192 getNumOutputs() == f.getNumOutputs(); 193 } 194 bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const { 195 return getNumDimIds() == f.getNumDimIds() && 196 getNumSymbolIds() == f.getNumSymbolIds() && 197 getNumOutputs() == f.getNumOutputs(); 198 } 199