//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case Kind::kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case Kind::kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case Kind::kZero:
    assert(x == -1u && y == -1u && !v);
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all points in s0 and all points in s1. For a subtraction,
  // the points of s1 alone must first be mapped to 0-y, since the left
  // operand is absent when only y contributes a value.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
    s1 = mapZero(kind, s1);
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::mapZero(Kind kind, unsigned s0) {
  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
  unsigned s = addSet();
  unsigned z = addExp(Kind::kZero);
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, z, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}
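
// Illustrative example (not from the original sources): for a kernel
// x(i) = a(i) * b(i) with both inputs sparse, takeConj() produces the
// single lattice point
//   { a, b } : a * b
// whereas for x(i) = a(i) + b(i), takeDisj() produces three points
//   { a, b } : a + b
//   { a }    : a
//   { b }    : b
// and for x(i) = a(i) - b(i), mapZero() rewrites the last point to
//   { b }    : 0 - b
// so that the value of b is still negated when only b contributes.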
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in the lattice, such that no other point is less
  // than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules: at most one dense condition needs to
  // be tested, and a singleton that still involves a sparse condition
  // needs no dense tests at all.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}
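
// Illustrative example (not from the original sources): for a kernel
// x(i) = a(i) * d(i) with a sparse and d dense, the single lattice point
// { a, d } is a singleton that involves a sparse condition, so its
// simplified condition reduces to { a }: iterating over the nonzeros of
// a suffices. For x(i) = d1(i) * d2(i) with both inputs dense, only one
// of the two identical dense conditions is retained.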
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    return tensorExps[e].tensor == t;
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
  case Kind::kDivF: // note: x / c only
  case Kind::kDivS:
  case Kind::kDivU:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

static char kindToOpSymbol(Kind kind) {
  switch (kind) {
  case Kind::kMulF:
  case Kind::kMulI:
    return '*';
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    return '/';
  case Kind::kAddF:
  case Kind::kAddI:
    return '+';
  case Kind::kSubF:
  case Kind::kSubI:
    return '-';
  case Kind::kAndI:
    return '&';
  case Kind::kOrI:
    return '|';
  default:
    break;
  }
  llvm_unreachable("unexpected kind");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kZero:
    llvm::dbgs() << "zero";
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // (including the zero expression) is set to a synthetic tensor with
    // undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kSubF:
  case Kind::kSubI:
    // Special case: 0-y is -y.
    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
      return mapZero(kind, // maps to 0-y with just y's lattices
                     buildLattices(tensorExps[e].children.e1, i));
    LLVM_FALLTHROUGH;
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kOrI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}
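
// Illustrative example (not from the original sources): a unary negation
// -b(i) reaches buildLattices() as the expression 0 - b (see the kZero
// handling in buildTensorExp() below); the kSubF/kSubI special case then
// maps directly to the lattice points of b, with each expression rewritten
// to 0 - b, avoiding a disjunction with an empty left-hand side.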
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

// Returns false only if the expression is a compile-time constant that is
// provably nonzero; everything else is conservatively treated as possibly
// zero, which is used to reject divisions with a possibly-zero divisor.
bool Merger::maybeZero(unsigned e) {
  if (tensorExps[e].kind == Kind::kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e0 = addExp(Kind::kZero);
      unsigned e1 = x.getValue();
      if (isa<NegFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain divisions.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(Kind::kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(Kind::kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(Kind::kOrI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}
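
// Illustrative example (not from the original sources): for a
// linalg.generic whose body computes
//   %0 = mulf %a, %b : f64
//   %1 = addf %0, %c : f64
//   linalg.yield %1 : f64
// buildTensorExpFromLinalg() returns the tensor expression
// (tensor_a * tensor_b) + tensor_c, i.e., kAddF over kMulF over the
// three kTensor leaves.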
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero:
    llvm_unreachable("unexpected non-op");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case Kind::kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case Kind::kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case Kind::kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case Kind::kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case Kind::kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case Kind::kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir