//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case Kind::kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case Kind::kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all points in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == Kind::kSubF)
    s1 = mapSet(Kind::kNegF, s1);
  else if (kind == Kind::kSubI)
    s1 = mapSet(Kind::kNegI, s1);
  // Followed by all points in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}
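
// Illustrative example (not part of the implementation): for an expression
// a(i) + b(i) with both operands sparse, takeDisj() first adds the
// conjunction point (a and b both present, computing a+b), then the points
// of s0 alone (only a present, computing a), then the points of s1 alone
// (only b present, computing b). For subtraction, the s1-only points are
// first rewritten from the binary 0-b into the unary -b via mapSet().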

unsigned Merger::mapSet(Kind kind, unsigned s0) {
  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, such that no other point is less
  // than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    return tensorExps[e].tensor == t;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
  case Kind::kDivF: // note: x / c only
  case Kind::kDivS:
  case Kind::kDivU:
  case Kind::kShrS: // note: x >> inv only
  case Kind::kShrU:
  case Kind::kShlI:
    return isConjunction(t, tensorExps[e].children.e0);
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}
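
// Illustrative note (not part of the implementation): a lattice point's bit
// vector holds one bit per (tensor t, loop index i) pair, stored at position
// numTensors * i + t (see addLat() above). With numTensors == 3, for example,
// bit 7 = 3 * 2 + 1 stands for tensor 1 at loop index 2. The tensor(b) and
// index(b) helpers used by dumpBits() below invert this encoding.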

#ifndef NDEBUG

//
// Print methods (for debugging).
//

// One symbol per Kind, in enum order; note the three "/" entries for
// kDivF, kDivS, and kDivU.
static const char *kOpSymbols[] = {
    "",  "",  "abs", "ceil", "floor", "-", "-",   "*",  "*",  "/", "/",
    "/", "+", "+",   "-",    "-",     "&", "|",   "^",  "a>>", ">>", "<<"};

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG
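
// Illustrative note (not part of the implementation): for a conjunction of
// two sparse tensors over a single loop, dumpSet() prints roughly
//
//   { #1
//      lat( i_0_0_S i_1_0_S : : (tensor_0 * tensor_1) )
//   }
//
// where each i_t_i_X entry names loop index i of tensor t with its dimension
// type printed as S (kSparse), D (kDense), T (kSingle), or U (kUndef), the
// second field shows the simplified bits once optimizeSet() has run, and the
// last field is the tensor expression of that lattice point.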

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
  case Kind::kOrI:
  case Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kShrS:
  case Kind::kShrU:
  case Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand side of the operator) can be handled
    // with the conjunction rule.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == Kind::kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  // Conservatively assume that anything else may be zero.
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == Kind::kInvariant;
}
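
// Illustrative note (not part of the implementation): for a generic op whose
// body is
//
//   %0 = mulf %a, %b : f64
//   linalg.yield %0 : f64
//
// buildTensorExpFromLinalg() starts at the yielded value and recursively
// constructs the expression tree kMulF(kTensor(0), kTensor(1)). Any value
// that cannot be classified (e.g., an unsupported op, or a division by a
// possibly-zero operand) makes the entire construction return None.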

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, v);
  }
  // Something defined outside the op region is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(Kind::kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(Kind::kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(Kind::kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(Kind::kNegF, e);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // divisions and non-invariant shift amounts.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(Kind::kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(Kind::kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(Kind::kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(Kind::kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(Kind::kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(Kind::kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(Kind::kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}
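
// Illustrative note (not part of the implementation): buildExp() below is
// the code-generation counterpart of buildTensorExp(); given the values v0
// and v1 already generated for the children of expression e, it emits the
// matching std-dialect operation. The leaf kinds kTensor and kInvariant are
// never passed in, since the caller materializes tensor loads and invariants
// itself. kNegI is the odd one out: std has no integer negation op, so the
// caller must supply both operands (note the assert on v1) and the negation
// is emitted as a subtraction.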

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("unexpected non-op");
  case Kind::kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case Kind::kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case Kind::kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case Kind::kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case Kind::kNegI:
    assert(v1); // no negi in std
    return rewriter.create<SubIOp>(loc, v0, v1);
  case Kind::kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case Kind::kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case Kind::kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case Kind::kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case Kind::kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case Kind::kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case Kind::kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case Kind::kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case Kind::kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case Kind::kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case Kind::kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir
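
// Illustrative usage sketch (hypothetical driver code, not part of this
// file): a client such as the sparsification pass would typically combine
// the builder methods above along these lines, assuming the set()/lat()
// accessors on Merger:
//
//   Merger merger(numTensors, numLoops);
//   if (Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op)) {
//     unsigned lts = merger.optimizeSet(merger.buildLattices(*exp, i));
//     for (unsigned p : merger.set(lts)) {
//       // ... emit loop structure from merger.lat(p).simple, then scalar
//       // code via merger.buildExp(rewriter, loc, merger.lat(p).exp, ...)
//     }
//   }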