1 //===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Analysis/Presburger/PresburgerSpace.h"
10 #include <algorithm>
11 #include <cassert>
12 
13 using namespace mlir;
14 using namespace presburger;
15 
getNumVarKind(VarKind kind) const16 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
17   if (kind == VarKind::Domain)
18     return getNumDomainVars();
19   if (kind == VarKind::Range)
20     return getNumRangeVars();
21   if (kind == VarKind::Symbol)
22     return getNumSymbolVars();
23   if (kind == VarKind::Local)
24     return numLocals;
25   llvm_unreachable("VarKind does not exist!");
26 }
27 
getVarKindOffset(VarKind kind) const28 unsigned PresburgerSpace::getVarKindOffset(VarKind kind) const {
29   if (kind == VarKind::Domain)
30     return 0;
31   if (kind == VarKind::Range)
32     return getNumDomainVars();
33   if (kind == VarKind::Symbol)
34     return getNumDimVars();
35   if (kind == VarKind::Local)
36     return getNumDimAndSymbolVars();
37   llvm_unreachable("VarKind does not exist!");
38 }
39 
getVarKindEnd(VarKind kind) const40 unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const {
41   return getVarKindOffset(kind) + getNumVarKind(kind);
42 }
43 
getVarKindOverlap(VarKind kind,unsigned varStart,unsigned varLimit) const44 unsigned PresburgerSpace::getVarKindOverlap(VarKind kind, unsigned varStart,
45                                             unsigned varLimit) const {
46   unsigned varRangeStart = getVarKindOffset(kind);
47   unsigned varRangeEnd = getVarKindEnd(kind);
48 
49   // Compute number of elements in intersection of the ranges [varStart,
50   // varLimit) and [varRangeStart, varRangeEnd).
51   unsigned overlapStart = std::max(varStart, varRangeStart);
52   unsigned overlapEnd = std::min(varLimit, varRangeEnd);
53 
54   if (overlapStart > overlapEnd)
55     return 0;
56   return overlapEnd - overlapStart;
57 }
58 
getVarKindAt(unsigned pos) const59 VarKind PresburgerSpace::getVarKindAt(unsigned pos) const {
60   assert(pos < getNumVars() && "`pos` should represent a valid var position");
61   if (pos < getVarKindEnd(VarKind::Domain))
62     return VarKind::Domain;
63   if (pos < getVarKindEnd(VarKind::Range))
64     return VarKind::Range;
65   if (pos < getVarKindEnd(VarKind::Symbol))
66     return VarKind::Symbol;
67   if (pos < getVarKindEnd(VarKind::Local))
68     return VarKind::Local;
69   llvm_unreachable("`pos` should represent a valid var position");
70 }
71 
insertVar(VarKind kind,unsigned pos,unsigned num)72 unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) {
73   assert(pos <= getNumVarKind(kind));
74 
75   unsigned absolutePos = getVarKindOffset(kind) + pos;
76 
77   if (kind == VarKind::Domain)
78     numDomain += num;
79   else if (kind == VarKind::Range)
80     numRange += num;
81   else if (kind == VarKind::Symbol)
82     numSymbols += num;
83   else
84     numLocals += num;
85 
86   // Insert NULL identifiers if `usingIds` and variables inserted are
87   // not locals.
88   if (usingIds && kind != VarKind::Local)
89     identifiers.insert(identifiers.begin() + absolutePos, num, nullptr);
90 
91   return absolutePos;
92 }
93 
removeVarRange(VarKind kind,unsigned varStart,unsigned varLimit)94 void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart,
95                                      unsigned varLimit) {
96   assert(varLimit <= getNumVarKind(kind) && "invalid var limit");
97 
98   if (varStart >= varLimit)
99     return;
100 
101   unsigned numVarsEliminated = varLimit - varStart;
102   if (kind == VarKind::Domain)
103     numDomain -= numVarsEliminated;
104   else if (kind == VarKind::Range)
105     numRange -= numVarsEliminated;
106   else if (kind == VarKind::Symbol)
107     numSymbols -= numVarsEliminated;
108   else
109     numLocals -= numVarsEliminated;
110 
111   // Remove identifiers if `usingIds` and variables removed are not
112   // locals.
113   if (usingIds && kind != VarKind::Local)
114     identifiers.erase(identifiers.begin() + getVarKindOffset(kind) + varStart,
115                       identifiers.begin() + getVarKindOffset(kind) + varLimit);
116 }
117 
swapVar(VarKind kindA,VarKind kindB,unsigned posA,unsigned posB)118 void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
119                               unsigned posB) {
120 
121   if (!usingIds)
122     return;
123 
124   if (kindA == VarKind::Local && kindB == VarKind::Local)
125     return;
126 
127   if (kindA == VarKind::Local) {
128     atId(kindB, posB) = nullptr;
129     return;
130   }
131 
132   if (kindB == VarKind::Local) {
133     atId(kindA, posA) = nullptr;
134     return;
135   }
136 
137   std::swap(atId(kindA, posA), atId(kindB, posB));
138 }
139 
isCompatible(const PresburgerSpace & other) const140 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
141   return getNumDomainVars() == other.getNumDomainVars() &&
142          getNumRangeVars() == other.getNumRangeVars() &&
143          getNumSymbolVars() == other.getNumSymbolVars();
144 }
145 
isEqual(const PresburgerSpace & other) const146 bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
147   return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars();
148 }
149 
isAligned(const PresburgerSpace & other) const150 bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
151   assert(isUsingIds() && other.isUsingIds() &&
152          "Both spaces should be using identifiers to check for "
153          "alignment.");
154   return isCompatible(other) && identifiers == other.identifiers;
155 }
156 
isAligned(const PresburgerSpace & other,VarKind kind) const157 bool PresburgerSpace::isAligned(const PresburgerSpace &other,
158                                 VarKind kind) const {
159   assert(isUsingIds() && other.isUsingIds() &&
160          "Both spaces should be using identifiers to check for "
161          "alignment.");
162 
163   ArrayRef<void *> kindAttachments =
164       makeArrayRef(identifiers)
165           .slice(getVarKindOffset(kind), getNumVarKind(kind));
166   ArrayRef<void *> otherKindAttachments =
167       makeArrayRef(other.identifiers)
168           .slice(other.getVarKindOffset(kind), other.getNumVarKind(kind));
169   return kindAttachments == otherKindAttachments;
170 }
171 
setVarSymbolSeperation(unsigned newSymbolCount)172 void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
173   assert(newSymbolCount <= getNumDimAndSymbolVars() &&
174          "invalid separation position");
175   numRange = numRange + numSymbols - newSymbolCount;
176   numSymbols = newSymbolCount;
177   // We do not need to change `identifiers` since the ordering of
178   // `identifiers` remains same.
179 }
180 
print(llvm::raw_ostream & os) const181 void PresburgerSpace::print(llvm::raw_ostream &os) const {
182   os << "Domain: " << getNumDomainVars() << ", "
183      << "Range: " << getNumRangeVars() << ", "
184      << "Symbols: " << getNumSymbolVars() << ", "
185      << "Locals: " << getNumLocalVars() << "\n";
186 
187   if (usingIds) {
188 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
189     os << "TypeID of identifiers: " << idType.getAsOpaquePointer() << "\n";
190 #endif
191 
192     os << "(";
193     for (void *identifier : identifiers)
194       os << identifier << " ";
195     os << ")\n";
196   }
197 }
198 
dump() const199 void PresburgerSpace::dump() const { print(llvm::errs()); }
200