1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/PWMAFunction.h" 10 #include "mlir/Analysis/Presburger/Simplex.h" 11 12 using namespace mlir; 13 using namespace presburger; 14 15 // Return the result of subtracting the two given vectors pointwise. 16 // The vectors must be of the same size. 17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA, 19 ArrayRef<int64_t> vecB) { 20 assert(vecA.size() == vecB.size() && 21 "Cannot subtract vectors of differing lengths!"); 22 SmallVector<int64_t, 8> result; 23 result.reserve(vecA.size()); 24 for (unsigned i = 0, e = vecA.size(); i < e; ++i) 25 result.push_back(vecA[i] - vecB[i]); 26 return result; 27 } 28 29 PresburgerSet PWMAFunction::getDomain() const { 30 PresburgerSet domain = 31 PresburgerSet::getEmpty(getNumDimIds(), getNumSymbolIds()); 32 for (const MultiAffineFunction &piece : pieces) 33 domain.unionInPlace(piece.getDomain()); 34 return domain; 35 } 36 37 Optional<SmallVector<int64_t, 8>> 38 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const { 39 assert(point.size() == getNumDimAndSymbolIds() && 40 "Point has incorrect dimensionality!"); 41 42 Optional<SmallVector<int64_t, 8>> maybeLocalValues = 43 getDomain().containsPointNoLocal(point); 44 if (!maybeLocalValues) 45 return {}; 46 47 // The point lies in the domain, so we need to compute the output value. 48 SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)}; 49 // The given point didn't include the values of locals which the output is a 50 // function of; we have computed one possible set of values and use them 51 // here. The function is not allowed to have local ids that take more than 52 // one possible value. 53 pointHomogenous.append(*maybeLocalValues); 54 // The matrix `output` has an affine expression in the ith row, corresponding 55 // to the expression for the ith value in the output vector. The last column 56 // of the matrix contains the constant term. Let v be the input point with 57 // a 1 appended at the end. We can see that output * v gives the desired 58 // output vector. 59 pointHomogenous.push_back(1); 60 SmallVector<int64_t, 8> result = 61 output.postMultiplyWithColumn(pointHomogenous); 62 assert(result.size() == getNumOutputs()); 63 return result; 64 } 65 66 Optional<SmallVector<int64_t, 8>> 67 PWMAFunction::valueAt(ArrayRef<int64_t> point) const { 68 assert(point.size() == getNumInputs() && 69 "Point has incorrect dimensionality!"); 70 for (const MultiAffineFunction &piece : pieces) 71 if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point)) 72 return output; 73 return {}; 74 } 75 76 void MultiAffineFunction::print(raw_ostream &os) const { 77 os << "Domain:"; 78 IntegerPolyhedron::print(os); 79 os << "Output:\n"; 80 output.print(os); 81 os << "\n"; 82 } 83 84 void MultiAffineFunction::dump() const { print(llvm::errs()); } 85 86 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { 87 return isSpaceCompatible(other) && getDomain().isEqual(other.getDomain()) && 88 isEqualWhereDomainsOverlap(other); 89 } 90 91 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos, 92 unsigned num) { 93 assert((kind != IdKind::Domain || num == 0) && 94 "Domain has to be zero in a set"); 95 unsigned absolutePos = getIdKindOffset(kind) + pos; 96 output.insertColumns(absolutePos, num); 97 return IntegerPolyhedron::insertId(kind, pos, num); 98 } 99 100 void MultiAffineFunction::swapId(unsigned posA, unsigned posB) { 101 output.swapColumns(posA, posB); 102 IntegerPolyhedron::swapId(posA, posB); 103 } 104 105 void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart, 106 unsigned idLimit) { 107 output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart); 108 IntegerPolyhedron::removeIdRange(kind, idStart, idLimit); 109 } 110 111 void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA, 112 unsigned posB) { 113 output.addToColumn(posB, posA, /*scale=*/1); 114 IntegerPolyhedron::eliminateRedundantLocalId(posA, posB); 115 } 116 117 bool MultiAffineFunction::isEqualWhereDomainsOverlap( 118 MultiAffineFunction other) const { 119 if (!isSpaceCompatible(other)) 120 return false; 121 122 // `commonFunc` has the same output as `this`. 123 MultiAffineFunction commonFunc = *this; 124 // After this merge, `commonFunc` and `other` have the same local ids; they 125 // are merged. 126 commonFunc.mergeLocalIds(other); 127 // After this, the domain of `commonFunc` will be the intersection of the 128 // domains of `this` and `other`. 129 commonFunc.IntegerPolyhedron::append(other); 130 131 // `commonDomainMatching` contains the subset of the common domain 132 // where the outputs of `this` and `other` match. 133 // 134 // We want to add constraints equating the outputs of `this` and `other`. 135 // However, `this` may have difference local ids from `other`, whereas we 136 // need both to have the same locals. Accordingly, we use `commonFunc.output` 137 // in place of `this->output`, since `commonFunc` has the same output but also 138 // has its locals merged. 139 IntegerPolyhedron commonDomainMatching = commonFunc.getDomain(); 140 for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) 141 commonDomainMatching.addEquality( 142 subtract(commonFunc.output.getRow(row), other.output.getRow(row))); 143 144 // If the whole common domain is a subset of commonDomainMatching, then they 145 // are equal and the two functions match on the whole common domain. 146 return commonFunc.getDomain().isSubsetOf(commonDomainMatching); 147 } 148 149 /// Two PWMAFunctions are equal if they have the same dimensionalities, 150 /// the same domain, and take the same value at every point in the domain. 151 bool PWMAFunction::isEqual(const PWMAFunction &other) const { 152 if (!isSpaceCompatible(other)) 153 return false; 154 155 if (!this->getDomain().isEqual(other.getDomain())) 156 return false; 157 158 // Check if, whenever the domains of a piece of `this` and a piece of `other` 159 // overlap, they take the same output value. If `this` and `other` have the 160 // same domain (checked above), then this check passes iff the two functions 161 // have the same output at every point in the domain. 162 for (const MultiAffineFunction &aPiece : this->pieces) 163 for (const MultiAffineFunction &bPiece : other.pieces) 164 if (!aPiece.isEqualWhereDomainsOverlap(bPiece)) 165 return false; 166 return true; 167 } 168 169 void PWMAFunction::addPiece(const MultiAffineFunction &piece) { 170 assert(piece.isSpaceCompatible(*this) && 171 "Piece to be added is not compatible with this PWMAFunction!"); 172 assert(piece.isConsistent() && "Piece is internally inconsistent!"); 173 assert(this->getDomain() 174 .intersect(PresburgerSet(piece.getDomain())) 175 .isIntegerEmpty() && 176 "New piece's domain overlaps with that of existing pieces!"); 177 pieces.push_back(piece); 178 } 179 180 void PWMAFunction::addPiece(const IntegerPolyhedron &domain, 181 const Matrix &output) { 182 addPiece(MultiAffineFunction(domain, output)); 183 } 184 185 void PWMAFunction::print(raw_ostream &os) const { 186 os << pieces.size() << " pieces:\n"; 187 for (const MultiAffineFunction &piece : pieces) 188 piece.print(os); 189 } 190 191 void PWMAFunction::dump() const { print(llvm::errs()); } 192