//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, so that no other point is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
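  // That is, all conditions on sparse dimensions are kept, while of the
  // conditions on non-sparse dimensions at most the first one is kept;
  // even that one is dropped when this singleton point already tests a
  // sparse dimension, since the sparse iteration drives the loop by itself.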
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kMulF:
  case Kind::kMulI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " * ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
    break;
  case Kind::kAddF:
  case Kind::kAddI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " + ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " / ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned idx) {
  Kind kind = tensorExps[e].kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
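    // In both cases the leaf maps to a lattice set with exactly one point,
    // built from the tensor (or the synthetic tensor) and the loop index.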
    unsigned s = addSet();
    unsigned t =
        kind == Kind::kTensor ? tensorExps[e].children.e0 : syntheticTensor;
    latSets[s].push_back(addLat(t, idx, e));
    return s;
  }
  unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
  unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, val);
  }
  // Something defined outside the op's region is invariant.
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, val);
  // Construct binary operations if subexpressions could be built.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

} // namespace sparse_tensor
} // namespace mlir