1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 11 namespace mlir { 12 namespace sparse_tensor { 13 14 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { 15 unsigned e = tensorExps.size(); 16 tensorExps.push_back(TensorExp(k, e0, e1, v)); 17 return e; 18 } 19 20 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { 21 assert(t < numTensors && i < numLoops); 22 unsigned p = latPoints.size(); 23 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); 24 return p; 25 } 26 27 unsigned Merger::addSet() { 28 unsigned s = latSets.size(); 29 latSets.emplace_back(SmallVector<unsigned, 16>()); 30 return s; 31 } 32 33 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 34 unsigned p = latPoints.size(); 35 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 36 nb |= latPoints[p1].bits; 37 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 38 latPoints.push_back(LatPoint(nb, e)); 39 return p; 40 } 41 42 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { 43 unsigned s = addSet(); 44 for (unsigned p0 : latSets[s0]) 45 for (unsigned p1 : latSets[s1]) 46 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 47 return s; 48 } 49 50 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { 51 unsigned s = takeConj(kind, s0, s1); 52 for (unsigned p : latSets[s0]) 53 latSets[s].push_back(p); 54 for (unsigned p : latSets[s1]) 55 latSets[s].push_back(p); 56 return s; 57 } 58 59 unsigned Merger::optimizeSet(unsigned s0) { 60 unsigned s = addSet(); 61 assert(latSets[s0].size() != 0); 62 unsigned p0 = latSets[s0][0]; 63 for (unsigned p1 : latSets[s0]) { 64 bool add = true; 65 if (p0 != p1) { 66 // Is this a straightforward copy? 67 unsigned e = latPoints[p1].exp; 68 if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) 69 continue; 70 // Conjunction already covered? 71 for (unsigned p2 : latSets[s]) { 72 assert(!latGT(p1, p2)); // Lj => Li would be bad 73 if (onlyDenseDiff(p2, p1)) { 74 add = false; 75 break; 76 } 77 } 78 assert(!add || latGT(p0, p1)); 79 } 80 if (add) 81 latSets[s].push_back(p1); 82 } 83 for (unsigned p : latSets[s]) 84 latPoints[p].simple = simplifyCond(s, p); 85 return s; 86 } 87 88 llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) { 89 // First determine if this lattice point is a *singleton*, i.e., 90 // the last point in a lattice, no other is less than this one. 91 bool isSingleton = true; 92 for (unsigned p1 : latSets[s]) { 93 if (p0 != p1 && latGT(p0, p1)) { 94 isSingleton = false; 95 break; 96 } 97 } 98 // Now apply the two basic rules. 99 llvm::BitVector simple = latPoints[p0].bits; 100 bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); 101 for (unsigned b = 0, be = simple.size(); b < be; b++) { 102 if (simple[b] && !isDim(b, Dim::kSparse)) { 103 if (reset) 104 simple.reset(b); 105 reset = true; 106 } 107 } 108 return simple; 109 } 110 111 bool Merger::latGT(unsigned i, unsigned j) const { 112 const llvm::BitVector &bitsi = latPoints[i].bits; 113 const llvm::BitVector &bitsj = latPoints[j].bits; 114 assert(bitsi.size() == bitsj.size()); 115 if (bitsi.count() > bitsj.count()) { 116 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 117 if (bitsj[b] && !bitsi[b]) 118 return false; 119 return true; 120 } 121 return false; 122 } 123 124 bool Merger::onlyDenseDiff(unsigned i, unsigned j) { 125 llvm::BitVector tmp = latPoints[j].bits; 126 tmp ^= latPoints[i].bits; 127 return !hasAnyDimOf(tmp, Dim::kSparse); 128 } 129 130 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 131 for (unsigned b = 0, be = bits.size(); b < be; b++) 132 if (bits[b] && isDim(b, d)) 133 return true; 134 return false; 135 } 136 137 } // namespace sparse_tensor 138 } // namespace mlir 139