1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 11 #include "mlir/IR/Operation.h" 12 #include "llvm/Support/Debug.h" 13 14 namespace mlir { 15 namespace sparse_tensor { 16 17 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { 18 unsigned e = tensorExps.size(); 19 tensorExps.push_back(TensorExp(k, e0, e1, v)); 20 return e; 21 } 22 23 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { 24 assert(t < numTensors && i < numLoops); 25 unsigned p = latPoints.size(); 26 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); 27 return p; 28 } 29 30 unsigned Merger::addSet() { 31 unsigned s = latSets.size(); 32 latSets.emplace_back(SmallVector<unsigned, 16>()); 33 return s; 34 } 35 36 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 37 unsigned p = latPoints.size(); 38 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 39 nb |= latPoints[p1].bits; 40 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 41 latPoints.push_back(LatPoint(nb, e)); 42 return p; 43 } 44 45 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { 46 unsigned s = addSet(); 47 for (unsigned p0 : latSets[s0]) 48 for (unsigned p1 : latSets[s1]) 49 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 50 return s; 51 } 52 53 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { 54 unsigned s = takeConj(kind, s0, s1); 55 for (unsigned p : latSets[s0]) 56 latSets[s].push_back(p); 57 for (unsigned p : latSets[s1]) 58 latSets[s].push_back(p); 59 return s; 60 } 61 62 unsigned Merger::optimizeSet(unsigned s0) { 63 unsigned s = addSet(); 64 assert(latSets[s0].size() != 0); 65 unsigned p0 = latSets[s0][0]; 66 for (unsigned p1 : latSets[s0]) { 67 bool add = true; 68 if (p0 != p1) { 69 // Is this a straightforward copy? 70 unsigned e = latPoints[p1].exp; 71 if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) 72 continue; 73 // Conjunction already covered? 74 for (unsigned p2 : latSets[s]) { 75 assert(!latGT(p1, p2)); // Lj => Li would be bad 76 if (onlyDenseDiff(p2, p1)) { 77 add = false; 78 break; 79 } 80 } 81 assert(!add || latGT(p0, p1)); 82 } 83 if (add) 84 latSets[s].push_back(p1); 85 } 86 for (unsigned p : latSets[s]) 87 latPoints[p].simple = simplifyCond(s, p); 88 return s; 89 } 90 91 llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) { 92 // First determine if this lattice point is a *singleton*, i.e., 93 // the last point in a lattice, no other is less than this one. 94 bool isSingleton = true; 95 for (unsigned p1 : latSets[s]) { 96 if (p0 != p1 && latGT(p0, p1)) { 97 isSingleton = false; 98 break; 99 } 100 } 101 // Now apply the two basic rules. 102 llvm::BitVector simple = latPoints[p0].bits; 103 bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); 104 for (unsigned b = 0, be = simple.size(); b < be; b++) { 105 if (simple[b] && !isDim(b, Dim::kSparse)) { 106 if (reset) 107 simple.reset(b); 108 reset = true; 109 } 110 } 111 return simple; 112 } 113 114 bool Merger::latGT(unsigned i, unsigned j) const { 115 const llvm::BitVector &bitsi = latPoints[i].bits; 116 const llvm::BitVector &bitsj = latPoints[j].bits; 117 assert(bitsi.size() == bitsj.size()); 118 if (bitsi.count() > bitsj.count()) { 119 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 120 if (bitsj[b] && !bitsi[b]) 121 return false; 122 return true; 123 } 124 return false; 125 } 126 127 bool Merger::onlyDenseDiff(unsigned i, unsigned j) { 128 llvm::BitVector tmp = latPoints[j].bits; 129 tmp ^= latPoints[i].bits; 130 return !hasAnyDimOf(tmp, Dim::kSparse); 131 } 132 133 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 134 for (unsigned b = 0, be = bits.size(); b < be; b++) 135 if (bits[b] && isDim(b, d)) 136 return true; 137 return false; 138 } 139 140 unsigned Merger::buildLattices(unsigned e, unsigned idx) { 141 Kind kind = exp(e).kind; 142 if (kind == Kind::kTensor || kind == Kind::kInvariant) { 143 // Either the index is really used in the tensor expression, or it is 144 // set to the undefined index in that dimension. An invariant expression 145 // is set to a synthetic tensor with undefined indices only. 146 unsigned s = addSet(); 147 unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor; 148 set(s).push_back(addLat(t, idx, e)); 149 return s; 150 } 151 unsigned s0 = buildLattices(exp(e).e0, idx); 152 unsigned s1 = buildLattices(exp(e).e1, idx); 153 switch (kind) { 154 case Kind::kTensor: 155 case Kind::kInvariant: 156 llvm_unreachable("handled above"); 157 case Kind::kMulF: 158 case Kind::kMulI: 159 return takeConj(kind, s0, s1); 160 case Kind::kAddF: 161 case Kind::kAddI: 162 return takeDisj(kind, s0, s1); 163 } 164 llvm_unreachable("unexpected expression kind"); 165 } 166 167 #ifndef NDEBUG 168 169 // 170 // Print methods (for debugging). 171 // 172 173 void Merger::dumpExp(unsigned e) const { 174 switch (tensorExps[e].kind) { 175 case Kind::kTensor: 176 llvm::dbgs() << "tensor_" << tensorExps[e].e0; 177 break; 178 case Kind::kInvariant: 179 llvm::dbgs() << "invariant"; 180 break; 181 default: 182 case Kind::kMulI: 183 llvm::dbgs() << "("; 184 dumpExp(tensorExps[e].e0); 185 llvm::dbgs() << " * "; 186 dumpExp(tensorExps[e].e1); 187 llvm::dbgs() << ")"; 188 break; 189 case Kind::kAddF: 190 case Kind::kAddI: 191 llvm::dbgs() << "("; 192 dumpExp(tensorExps[e].e0); 193 llvm::dbgs() << " + "; 194 dumpExp(tensorExps[e].e1); 195 llvm::dbgs() << ")"; 196 break; 197 } 198 } 199 200 void Merger::dumpLat(unsigned p) const { 201 llvm::dbgs() << "lat("; 202 dumpBits(latPoints[p].bits); 203 llvm::dbgs() << " :"; 204 dumpBits(latPoints[p].simple); 205 llvm::dbgs() << " / "; 206 dumpExp(latPoints[p].exp); 207 llvm::dbgs() << " )\n"; 208 } 209 210 void Merger::dumpSet(unsigned s) const { 211 llvm::dbgs() << "{ #" << latSets[s].size() << "\n"; 212 for (unsigned p : latSets[s]) { 213 llvm::dbgs() << " "; 214 dumpLat(p); 215 } 216 llvm::dbgs() << "}\n"; 217 } 218 219 void Merger::dumpBits(const llvm::BitVector &bits) const { 220 for (unsigned b = 0, be = bits.size(); b < be; b++) { 221 if (bits[b]) { 222 unsigned t = tensor(b); 223 unsigned i = index(b); 224 llvm::dbgs() << " i_" << t << "_" << i << "_"; 225 switch (dims[t][i]) { 226 case Dim::kSparse: 227 llvm::dbgs() << "S"; 228 break; 229 case Dim::kDense: 230 llvm::dbgs() << "D"; 231 break; 232 case Dim::kSingle: 233 llvm::dbgs() << "T"; 234 break; 235 case Dim::kUndef: 236 llvm::dbgs() << "U"; 237 break; 238 } 239 } 240 } 241 } 242 243 #endif // NDEBUG 244 245 } // namespace sparse_tensor 246 } // namespace mlir 247