//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0) {
  assert(kAbsF <= kind && kind <= kNegI);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    return isConjunction(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << " ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
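    // As a hypothetical illustration (not taken from an actual kernel), for
    //   linalg.generic ins(%A, %b : tensor<?xf64>, f64) outs(%X : tensor<?xf64>)
    // the block argument tied to %A yields a kTensor expression below, while
    // the scalar %b (and any argument of an enveloping op) falls through to
    // the kInvariant case.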
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(kNegF, e);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain divisions.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  case kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case kNegI:
    assert(v1); // no negi in std
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir