1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17
18 namespace mlir {
19 namespace sparse_tensor {
20
21 //===----------------------------------------------------------------------===//
22 // Constructors.
23 //===----------------------------------------------------------------------===//
24
25 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
26 : kind(k), val(v), op(o) {
27 switch (kind) {
28 // Leaf.
29 case kTensor:
30 assert(x != -1u && y == -1u && !v && !o);
31 tensor = x;
32 break;
33 case kInvariant:
34 assert(x == -1u && y == -1u && v && !o);
35 break;
36 case kIndex:
37 assert(x != -1u && y == -1u && !v && !o);
38 index = x;
39 break;
40 // Unary operations.
41 case kAbsF:
42 case kAbsC:
43 case kCeilF:
44 case kFloorF:
45 case kSqrtF:
46 case kSqrtC:
47 case kExpm1F:
48 case kExpm1C:
49 case kLog1pF:
50 case kLog1pC:
51 case kSinF:
52 case kSinC:
53 case kTanhF:
54 case kTanhC:
55 case kNegF:
56 case kNegC:
57 case kNegI:
58 case kCIm:
59 case kCRe:
60 assert(x != -1u && y == -1u && !v && !o);
61 children.e0 = x;
62 children.e1 = y;
63 break;
64 case kTruncF:
65 case kExtF:
66 case kCastFS:
67 case kCastFU:
68 case kCastSF:
69 case kCastUF:
70 case kCastS:
71 case kCastU:
72 case kCastIdx:
73 case kTruncI:
74 case kBitCast:
75 assert(x != -1u && y == -1u && v && !o);
76 children.e0 = x;
77 children.e1 = y;
78 break;
79 case kBinaryBranch:
80 assert(x != -1u && y == -1u && !v && o);
81 children.e0 = x;
82 children.e1 = y;
83 break;
84 case kUnary:
85 // No assertion on y can be made, as the branching paths involve both
86 // a unary (mapSet) and binary (takeDisj) pathway.
87 assert(x != -1u && !v && o);
88 children.e0 = x;
89 children.e1 = y;
90 break;
91 // Binary operations.
92 case kMulF:
93 case kMulC:
94 case kMulI:
95 case kDivF:
96 case kDivC:
97 case kDivS:
98 case kDivU:
99 case kAddF:
100 case kAddC:
101 case kAddI:
102 case kSubF:
103 case kSubC:
104 case kSubI:
105 case kAndI:
106 case kOrI:
107 case kXorI:
108 case kShrS:
109 case kShrU:
110 case kShlI:
111 assert(x != -1u && y != -1u && !v && !o);
112 children.e0 = x;
113 children.e1 = y;
114 break;
115 case kBinary:
116 assert(x != -1u && y != -1u && !v && o);
117 children.e0 = x;
118 children.e1 = y;
119 break;
120 }
121 }
122
123 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
124 : bits(n, false), simple(), exp(e) {
125 bits.set(b);
126 }
127
128 LatPoint::LatPoint(const BitVector &b, unsigned e)
129 : bits(b), simple(), exp(e) {}
130
131 //===----------------------------------------------------------------------===//
132 // Lattice methods.
133 //===----------------------------------------------------------------------===//
134
135 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
136 Operation *op) {
137 unsigned e = tensorExps.size();
138 tensorExps.push_back(TensorExp(k, e0, e1, v, op));
139 return e;
140 }
141
142 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
143 assert(t < numTensors && i < numLoops);
144 unsigned p = latPoints.size();
145 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
146 return p;
147 }
148
149 unsigned Merger::addSet() {
150 unsigned s = latSets.size();
151 latSets.emplace_back(SmallVector<unsigned, 16>());
152 return s;
153 }
154
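// Constructs a lattice point for the conjunction of points p0 and p1: the new
// point takes the union of both bit sets and an expression that applies `kind`
// to both operand expressions.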
155 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
156 Operation *op) {
157 unsigned p = latPoints.size();
158 BitVector nb = BitVector(latPoints[p0].bits);
159 nb |= latPoints[p1].bits;
160 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
161 latPoints.push_back(LatPoint(nb, e));
162 return p;
163 }
164
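// Conjunctive merge of two lattice sets: adds one lattice point for every
// pairwise combination of points from s0 and s1.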
165 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
166 unsigned s = addSet();
167 for (unsigned p0 : latSets[s0])
168 for (unsigned p1 : latSets[s1])
169 latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
170 return s;
171 }
172
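// Disjunctive merge of two lattice sets: the conjunction of both sets,
// followed by the points of s0 alone and then the points of s1 alone
// (with s1 mapped to a unary negation for subtraction).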
173 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
174 unsigned s = takeConj(kind, s0, s1, op);
175 // Followed by all in s0.
176 for (unsigned p : latSets[s0])
177 latSets[s].push_back(p);
178 // Map binary 0-y to unary -y.
179 // TODO: move this if-else logic into buildLattices
180 if (kind == kSubF)
181 s1 = mapSet(kNegF, s1);
182 else if (kind == kSubC)
183 s1 = mapSet(kNegC, s1);
184 else if (kind == kSubI)
185 s1 = mapSet(kNegI, s1);
186 // Followed by all in s1.
187 for (unsigned p : latSets[s1])
188 latSets[s].push_back(p);
189 return s;
190 }
191
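// Disjunctive merge for custom binary operations: the overlap (conjunction)
// first, then optionally the left-only and right-only points, each mapped
// through the corresponding branch operation when one is present.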
192 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
193 bool includeLeft, Kind ltrans, Operation *opleft,
194 bool includeRight, Kind rtrans, Operation *opright) {
195 unsigned s = takeConj(kind, s0, s1, orig);
196 // Left Region.
197 if (includeLeft) {
198 if (opleft)
199 s0 = mapSet(ltrans, s0, Value(), opleft);
200 for (unsigned p : latSets[s0])
201 latSets[s].push_back(p);
202 }
203 // Right Region.
204 if (includeRight) {
205 if (opright)
206 s1 = mapSet(rtrans, s1, Value(), opright);
207 for (unsigned p : latSets[s1])
208 latSets[s].push_back(p);
209 }
210 return s;
211 }
212
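// Maps a unary operator over every lattice point of set s0: each new point
// keeps the original bits but wraps its expression in `kind`.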
213 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
214 assert(kAbsF <= kind && kind <= kUnary);
215 unsigned s = addSet();
216 for (unsigned p : latSets[s0]) {
217 unsigned e = addExp(kind, latPoints[p].exp, v, op);
218 latPoints.push_back(LatPoint(latPoints[p].bits, e));
219 latSets[s].push_back(latPoints.size() - 1);
220 }
221 return s;
222 }
223
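// Optimizes a lattice set by dropping points that merely copy the output
// tensor or whose conditions are already covered, up to dense dimensions, by
// a previously kept point, and by computing simplified conditions for the
// points that remain.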
224 unsigned Merger::optimizeSet(unsigned s0) {
225 unsigned s = addSet();
226 assert(!latSets[s0].empty());
227 unsigned p0 = latSets[s0][0];
228 for (unsigned p1 : latSets[s0]) {
229 bool add = true;
230 if (p0 != p1) {
231 // Is this a straightforward copy?
232 unsigned e = latPoints[p1].exp;
233 if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
234 continue;
235 // Conjunction already covered?
236 for (unsigned p2 : latSets[s]) {
237 assert(!latGT(p1, p2)); // Lj => Li would be bad
238 if (onlyDenseDiff(p2, p1)) {
239 add = false;
240 break;
241 }
242 }
243 assert(!add || latGT(p0, p1));
244 }
245 if (add)
246 latSets[s].push_back(p1);
247 }
248 for (unsigned p : latSets[s])
249 latPoints[p].simple = simplifyCond(s, p);
250 return s;
251 }
252
253 BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
254 // First determine if this lattice point is a *singleton*, i.e.,
255   // the last point in the lattice, so that no other point is less than it.
256 bool isSingleton = true;
257 for (unsigned p1 : latSets[s0]) {
258 if (p0 != p1 && latGT(p0, p1)) {
259 isSingleton = false;
260 break;
261 }
262 }
263 // Now apply the two basic rules.
264 BitVector simple = latPoints[p0].bits;
265 bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
266 for (unsigned b = 0, be = simple.size(); b < be; b++) {
267 if (simple[b] && !isDim(b, kSparse)) {
268 if (reset)
269 simple.reset(b);
270 reset = true;
271 }
272 }
273 return simple;
274 }
275
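// Returns true if lattice point i is strictly greater than point j, i.e.,
// the bit set of i properly contains the bit set of j.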
276 bool Merger::latGT(unsigned i, unsigned j) const {
277 const BitVector &bitsi = latPoints[i].bits;
278 const BitVector &bitsj = latPoints[j].bits;
279 assert(bitsi.size() == bitsj.size());
280 if (bitsi.count() > bitsj.count()) {
281 for (unsigned b = 0, be = bitsj.size(); b < be; b++)
282 if (bitsj[b] && !bitsi[b])
283 return false;
284 return true;
285 }
286 return false;
287 }
288
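// Returns true if lattice points i and j differ in dense dimensions only,
// i.e., no sparse dimension appears in their symmetric bit difference.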
289 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
290 BitVector tmp = latPoints[j].bits;
291 tmp ^= latPoints[i].bits;
292 return !hasAnyDimOf(tmp, kSparse);
293 }
294
295 bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
296 for (unsigned b = 0, be = bits.size(); b < be; b++)
297 if (bits[b] && isDim(b, d))
298 return true;
299 return false;
300 }
301
302 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
303 switch (tensorExps[e].kind) {
304 // Leaf.
305 case kTensor:
306 return tensorExps[e].tensor == t;
307 case kInvariant:
308 case kIndex:
309 return false;
310 // Unary operations.
311 case kAbsF:
312 case kAbsC:
313 case kCeilF:
314 case kFloorF:
315 case kSqrtF:
316 case kSqrtC:
317 case kExpm1F:
318 case kExpm1C:
319 case kLog1pF:
320 case kLog1pC:
321 case kSinF:
322 case kSinC:
323 case kTanhF:
324 case kTanhC:
325 case kNegF:
326 case kNegC:
327 case kNegI:
328 case kTruncF:
329 case kExtF:
330 case kCastFS:
331 case kCastFU:
332 case kCastSF:
333 case kCastUF:
334 case kCastS:
335 case kCastU:
336 case kCastIdx:
337 case kTruncI:
338 case kCIm:
339 case kCRe:
340 case kBitCast:
341 return isSingleCondition(t, tensorExps[e].children.e0);
342 case kBinaryBranch:
343 case kUnary:
344 return false;
345 // Binary operations.
346 case kDivF: // note: x / c only
347 case kDivC:
348 case kDivS:
349 case kDivU:
350 assert(!maybeZero(tensorExps[e].children.e1));
351 return isSingleCondition(t, tensorExps[e].children.e0);
352 case kShrS: // note: x >> inv only
353 case kShrU:
354 case kShlI:
355 assert(isInvariant(tensorExps[e].children.e1));
356 return isSingleCondition(t, tensorExps[e].children.e0);
357 case kMulF:
358 case kMulC:
359 case kMulI:
360 case kAndI:
361 if (isSingleCondition(t, tensorExps[e].children.e0))
362 return isSingleCondition(t, tensorExps[e].children.e1) ||
363 isInvariant(tensorExps[e].children.e1);
364 if (isSingleCondition(t, tensorExps[e].children.e1))
365 return isInvariant(tensorExps[e].children.e0);
366 return false;
367 case kAddF:
368 case kAddC:
369 case kAddI:
370 return isSingleCondition(t, tensorExps[e].children.e0) &&
371 isSingleCondition(t, tensorExps[e].children.e1);
372 case kSubF:
373 case kSubC:
374 case kSubI:
375 case kOrI:
376 case kXorI:
377 case kBinary:
378 return false;
379 }
380 llvm_unreachable("unexpected kind");
381 }
382
383 #ifndef NDEBUG
384
385 //===----------------------------------------------------------------------===//
386 // Print methods (for debugging).
387 //===----------------------------------------------------------------------===//
388
389 static const char *kindToOpSymbol(Kind kind) {
390 switch (kind) {
391 // Leaf.
392 case kTensor:
393 return "tensor";
394 case kInvariant:
395 return "invariant";
396 case kIndex:
397 return "index";
398 // Unary operations.
399 case kAbsF:
400 case kAbsC:
401 return "abs";
402 case kCeilF:
403 return "ceil";
404 case kFloorF:
405 return "floor";
406 case kSqrtF:
407 case kSqrtC:
408 return "sqrt";
409 case kExpm1F:
410 case kExpm1C:
411 return "expm1";
412 case kLog1pF:
413 case kLog1pC:
414 return "log1p";
415 case kSinF:
416 case kSinC:
417 return "sin";
418 case kTanhF:
419 case kTanhC:
420 return "tanh";
421 case kNegF:
422 case kNegC:
423 case kNegI:
424 return "-";
425 case kTruncF:
426 case kExtF:
427 case kCastFS:
428 case kCastFU:
429 case kCastSF:
430 case kCastUF:
431 case kCastS:
432 case kCastU:
433 case kCastIdx:
434   case kTruncI:
435   case kBitCast:
436     return "cast";
437   case kCIm:
438     return "complex.im";
439   case kCRe:
440     return "complex.re";
441 case kBinaryBranch:
442 return "binary_branch";
443 case kUnary:
444 return "unary";
445 // Binary operations.
446 case kMulF:
447 case kMulC:
448 case kMulI:
449 return "*";
450 case kDivF:
451 case kDivC:
452 case kDivS:
453 case kDivU:
454 return "/";
455 case kAddF:
456 case kAddC:
457 case kAddI:
458 return "+";
459 case kSubF:
460 case kSubC:
461 case kSubI:
462 return "-";
463 case kAndI:
464 return "&";
465 case kOrI:
466 return "|";
467 case kXorI:
468 return "^";
469 case kShrS:
470 return "a>>";
471 case kShrU:
472 return ">>";
473 case kShlI:
474 return "<<";
475 case kBinary:
476 return "binary";
477 }
478 llvm_unreachable("unexpected kind for symbol");
479 }
480
481 void Merger::dumpExp(unsigned e) const {
482 switch (tensorExps[e].kind) {
483 // Leaf.
484 case kTensor:
485 if (tensorExps[e].tensor == syntheticTensor)
486 llvm::dbgs() << "synthetic_";
487 else if (tensorExps[e].tensor == outTensor)
488 llvm::dbgs() << "output_";
489 llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
490 break;
491 case kInvariant:
492 llvm::dbgs() << "invariant";
493 break;
494 case kIndex:
495 llvm::dbgs() << "index_" << tensorExps[e].index;
496 break;
497 // Unary operations.
498 case kAbsF:
499 case kAbsC:
500 case kCeilF:
501 case kFloorF:
502 case kSqrtF:
503 case kSqrtC:
504 case kExpm1F:
505 case kExpm1C:
506 case kLog1pF:
507 case kLog1pC:
508 case kSinF:
509 case kSinC:
510 case kTanhF:
511 case kTanhC:
512 case kNegF:
513 case kNegC:
514 case kNegI:
515 case kTruncF:
516 case kExtF:
517 case kCastFS:
518 case kCastFU:
519 case kCastSF:
520 case kCastUF:
521 case kCastS:
522 case kCastU:
523 case kCastIdx:
524 case kTruncI:
525 case kCIm:
526 case kCRe:
527 case kBitCast:
528 case kBinaryBranch:
529 case kUnary:
530 llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
531 dumpExp(tensorExps[e].children.e0);
532 break;
533 // Binary operations.
534 case kMulF:
535 case kMulC:
536 case kMulI:
537 case kDivF:
538 case kDivC:
539 case kDivS:
540 case kDivU:
541 case kAddF:
542 case kAddC:
543 case kAddI:
544 case kSubF:
545 case kSubC:
546 case kSubI:
547 case kAndI:
548 case kOrI:
549 case kXorI:
550 case kShrS:
551 case kShrU:
552 case kShlI:
553 case kBinary:
554 llvm::dbgs() << "(";
555 dumpExp(tensorExps[e].children.e0);
556 llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
557 dumpExp(tensorExps[e].children.e1);
558 llvm::dbgs() << ")";
559 }
560 }
561
562 void Merger::dumpLat(unsigned p) const {
563 llvm::dbgs() << "lat(";
564 dumpBits(latPoints[p].bits);
565 llvm::dbgs() << " :";
566 dumpBits(latPoints[p].simple);
567 llvm::dbgs() << " : ";
568 dumpExp(latPoints[p].exp);
569 llvm::dbgs() << " )\n";
570 }
571
572 void Merger::dumpSet(unsigned s) const {
573 llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
574 for (unsigned p : latSets[s]) {
575 llvm::dbgs() << " ";
576 dumpLat(p);
577 }
578 llvm::dbgs() << "}\n";
579 }
580
581 void Merger::dumpBits(const BitVector &bits) const {
582 for (unsigned b = 0, be = bits.size(); b < be; b++) {
583 if (bits[b]) {
584 unsigned t = tensor(b);
585 unsigned i = index(b);
586 llvm::dbgs() << " i_" << t << "_" << i << "_";
587 switch (dims[t][i]) {
588 case kSparse:
589 llvm::dbgs() << "S";
590 break;
591 case kDense:
592 llvm::dbgs() << "D";
593 break;
594 case kSingle:
595 llvm::dbgs() << "T";
596 break;
597 case kUndef:
598 llvm::dbgs() << "U";
599 break;
600 }
601 }
602 }
603 }
604
605 #endif // NDEBUG
606
607 //===----------------------------------------------------------------------===//
608 // Builder methods.
609 //===----------------------------------------------------------------------===//
610
611 unsigned Merger::buildLattices(unsigned e, unsigned i) {
612 Kind kind = tensorExps[e].kind;
613 switch (kind) {
614 // Leaf.
615 case kTensor:
616 case kInvariant:
617 case kIndex: {
618 // Either the index is really used in the tensor expression, or it is
619 // set to the undefined index in that dimension. An invariant expression,
620 // a proper index value, and a truly dynamic sparse output tensor are set
621 // to a synthetic tensor with undefined indices only to ensure the
622 // iteration space is not skipped as a result of their contents.
623 unsigned s = addSet();
624 unsigned t = syntheticTensor;
625 if (kind == kTensor) {
626 t = tensorExps[e].tensor;
627 if (hasSparseOut && t == outTensor)
628 t = syntheticTensor;
629 }
630 latSets[s].push_back(addLat(t, i, e));
631 return s;
632 }
633 // Unary operations.
634 case kAbsF:
635 case kAbsC:
636 case kCeilF:
637 case kFloorF:
638 case kSqrtF:
639 case kSqrtC:
640 case kExpm1F:
641 case kExpm1C:
642 case kLog1pF:
643 case kLog1pC:
644 case kSinF:
645 case kSinC:
646 case kTanhF:
647 case kTanhC:
648 case kNegF:
649 case kNegC:
650 case kNegI:
651 case kTruncF:
652 case kExtF:
653 case kCastFS:
654 case kCastFU:
655 case kCastSF:
656 case kCastUF:
657 case kCastS:
658 case kCastU:
659 case kCastIdx:
660 case kTruncI:
661 case kCIm:
662 case kCRe:
663 case kBitCast:
664 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
665 // lattice set of the operand through the operator into a new set.
666 //
667 // -y|!y | y |
668 // --+---+---+
669 // | 0 |-y |
670 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
671 tensorExps[e].val);
672 case kBinaryBranch:
673 // The left or right half of a binary operation which has already
674 // been split into separate operations for each region.
675 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
676 tensorExps[e].op);
677 case kUnary:
678 // A custom unary operation.
679 //
680 // op y| !y | y |
681 // ----+----------+------------+
682 // | absent() | present(y) |
683 {
684 unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
685 UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
686 Region &absentRegion = unop.getAbsentRegion();
687
688 if (absentRegion.empty()) {
689 // Simple mapping over existing values.
690 return mapSet(kind, child0, Value(), unop);
691 } // Use a disjunction with `unop` on the left and the absent value as an
692 // invariant on the right.
693 Block &absentBlock = absentRegion.front();
694 YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
695 Value absentVal = absentYield.getResult();
696 unsigned rhs = addExp(kInvariant, absentVal);
697 return takeDisj(kind, child0, buildLattices(rhs, i), unop);
698 }
699 // Binary operations.
700 case kMulF:
701 case kMulC:
702 case kMulI:
703 case kAndI:
704 // A multiplicative operation only needs to be performed
705 // for the conjunction of sparse iteration spaces.
706 //
707 // x*y|!y | y |
708 // ---+---+---+
709 // !x | 0 | 0 |
710 // x | 0 |x*y|
711 //
712 // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
713 return takeConj(kind, // take binary conjunction
714 buildLattices(tensorExps[e].children.e0, i),
715 buildLattices(tensorExps[e].children.e1, i));
716 case kDivF:
717 case kDivC:
718 case kDivS:
719 case kDivU:
720 // A division is tricky, since 0/0, 0/c, c/0 all have
721 // specific outcomes for floating-point and integers.
722 // Thus, we need to traverse the full iteration space.
723 //
724 // x/y|!y | y |
725 // ---+---+---+
726 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
727 // x |x/0|x/y| INT: x/0=exception for any x
728 //
729 // TODO: for now we "fixed" this by only accepting x/c cases
730 // during expression building, so that the conjunction
731     //       rule applies (viz. x/c = x*(1/c) as far as lattice
732 // construction is concerned).
733 assert(!maybeZero(tensorExps[e].children.e1));
734 return takeConj(kind, // take binary conjunction
735 buildLattices(tensorExps[e].children.e0, i),
736 buildLattices(tensorExps[e].children.e1, i));
737 case kAddF:
738 case kAddC:
739 case kAddI:
740 case kSubF:
741 case kSubC:
742 case kSubI:
743 case kOrI:
744 case kXorI:
745 // An additive operation needs to be performed
746 // for the disjunction of sparse iteration spaces.
747 //
748 // x+y|!y | y | x-y|!y | y |
749 // ---+---+---+ ---+---+---+
750 // !x | 0 | y | !x | 0 |-y |
751 // x | x |x+y| x | x |x-y|
752 return takeDisj(kind, // take binary disjunction
753 buildLattices(tensorExps[e].children.e0, i),
754 buildLattices(tensorExps[e].children.e1, i));
755 case kShrS:
756 case kShrU:
757 case kShlI:
758 // A shift operation by an invariant amount (viz. tensor expressions
759 // can only occur at the left-hand-side of the operator) can be handled
760     // with the conjunction rule.
761 assert(isInvariant(tensorExps[e].children.e1));
762 return takeConj(kind, // take binary conjunction
763 buildLattices(tensorExps[e].children.e0, i),
764 buildLattices(tensorExps[e].children.e1, i));
765 case kBinary:
766 // A custom binary operation.
767 //
768 // x op y| !y | y |
769 // ------+---------+--------------+
770 // !x | empty | right(y) |
771 // x | left(x) | overlap(x,y) |
772 {
773 unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
774 unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
775 BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
776 Region &leftRegion = binop.getLeftRegion();
777 Region &rightRegion = binop.getRightRegion();
778 // Left Region.
779 Operation *leftYield = nullptr;
780 if (!leftRegion.empty()) {
781 Block &leftBlock = leftRegion.front();
782 leftYield = leftBlock.getTerminator();
783 }
784 // Right Region.
785 Operation *rightYield = nullptr;
786 if (!rightRegion.empty()) {
787 Block &rightBlock = rightRegion.front();
788 rightYield = rightBlock.getTerminator();
789 }
790 bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
791 bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
792 return takeCombi(kBinary, child0, child1, binop, includeLeft,
793 kBinaryBranch, leftYield, includeRight, kBinaryBranch,
794 rightYield);
795 }
796 }
797 llvm_unreachable("unexpected expression kind");
798 }
799
800 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
801 // Build the linalg semantics backward from yield.
802 Operation *yield = op.region().front().getTerminator();
803 assert(isa<linalg::YieldOp>(yield));
804 return buildTensorExp(op, yield->getOperand(0));
805 }
806
807 /// Only returns false if we are certain this is a nonzero.
808 bool Merger::maybeZero(unsigned e) const {
809 if (tensorExps[e].kind == kInvariant) {
810 if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
811 ArrayAttr arrayAttr = c.getValue();
812 return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
813              arrayAttr[1].cast<FloatAttr>().getValue().isZero();
814 }
815 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
816 return c.value() == 0;
817 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
818 return c.value().isZero();
819 }
820 return true;
821 }
822
823 bool Merger::isInvariant(unsigned e) const {
824 return tensorExps[e].kind == kInvariant;
825 }
826
827 Type Merger::inferType(unsigned e, Value src) {
828 // Obtain the destination type from the cast node.
829 Type dtp = tensorExps[e].val.getType();
830 // Inspect source type. For vector types, apply the same
831 // vectorization to the destination type.
832 if (auto vtp = src.getType().dyn_cast<VectorType>())
833 return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
834 return dtp;
835 }
836
837 /// Ensures that the sparse compiler can generate code for the expression.
838 static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
839 // Arguments are always admissable.
840 if (auto arg = v.dyn_cast<BlockArgument>())
841 return true;
842 // Accept index anywhere.
843 Operation *def = v.getDefiningOp();
844 if (isa<linalg::IndexOp>(def))
845 return true;
846 // Operation defined outside branch.
847 if (def->getBlock() != block) {
848 return def->getBlock() != op->getBlock(); // invariant?
849 }
850 // Operation defined within branch. Anything is accepted,
851 // as long as all subexpressions are admissable.
852 for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
853 if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
854 return false;
855 return true;
856 }
857
858 /// Ensures that the sparse compiler can generate code for the branch.
859 static bool isAdmissableBranch(Operation *op, Region &region) {
860 if (region.empty())
861 return true;
862 // Build the semi-ring branch semantics backward from yield.
863 Operation *yield = region.front().getTerminator();
864 assert(isa<YieldOp>(yield));
865   return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
866 }
867
868 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
869 if (auto arg = v.dyn_cast<BlockArgument>()) {
870 unsigned argN = arg.getArgNumber();
871 // Any argument of the generic op that is not marked as a scalar
872 // argument is considered a tensor, indexed by the implicit loop
873 // bounds. This includes rank-0 tensor arguments.
874 if (arg.getOwner()->getParentOp() == op) {
875 OpOperand *t = op.getInputAndOutputOperands()[argN];
876 if (!op.isScalar(t))
877 return addExp(kTensor, argN);
878 v = t->get(); // get scalar value
879 }
880 // Any other argument (marked as scalar argument for the generic op
881 // or belonging to an enveloping op) is considered invariant.
882 return addExp(kInvariant, v);
883 }
884 // Something defined outside is invariant.
885 Operation *def = v.getDefiningOp();
886 if (def->getBlock() != &op.region().front())
887 return addExp(kInvariant, v);
888 // Construct index operations.
889 if (def->getNumOperands() == 0) {
890 if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
891 return addExp(kIndex, indexOp.dim());
892 }
893 // Construct unary operations if subexpression can be built.
894 if (def->getNumOperands() == 1) {
895 auto x = buildTensorExp(op, def->getOperand(0));
896 if (x.has_value()) {
897 unsigned e = x.value();
898 if (isa<math::AbsOp>(def))
899 return addExp(kAbsF, e);
900 if (isa<complex::AbsOp>(def))
901 return addExp(kAbsC, e);
902 if (isa<math::CeilOp>(def))
903 return addExp(kCeilF, e);
904 if (isa<math::FloorOp>(def))
905 return addExp(kFloorF, e);
906 if (isa<math::SqrtOp>(def))
907 return addExp(kSqrtF, e);
908 if (isa<complex::SqrtOp>(def))
909 return addExp(kSqrtC, e);
910 if (isa<math::ExpM1Op>(def))
911 return addExp(kExpm1F, e);
912 if (isa<complex::Expm1Op>(def))
913 return addExp(kExpm1C, e);
914 if (isa<math::Log1pOp>(def))
915 return addExp(kLog1pF, e);
916 if (isa<complex::Log1pOp>(def))
917 return addExp(kLog1pC, e);
918 if (isa<math::SinOp>(def))
919 return addExp(kSinF, e);
920 if (isa<complex::SinOp>(def))
921 return addExp(kSinC, e);
922 if (isa<math::TanhOp>(def))
923 return addExp(kTanhF, e);
924 if (isa<complex::TanhOp>(def))
925 return addExp(kTanhC, e);
926 if (isa<arith::NegFOp>(def))
927 return addExp(kNegF, e); // no negi in std
928 if (isa<complex::NegOp>(def))
929 return addExp(kNegC, e);
930 if (isa<arith::TruncFOp>(def))
931 return addExp(kTruncF, e, v);
932 if (isa<arith::ExtFOp>(def))
933 return addExp(kExtF, e, v);
934 if (isa<arith::FPToSIOp>(def))
935 return addExp(kCastFS, e, v);
936 if (isa<arith::FPToUIOp>(def))
937 return addExp(kCastFU, e, v);
938 if (isa<arith::SIToFPOp>(def))
939 return addExp(kCastSF, e, v);
940 if (isa<arith::UIToFPOp>(def))
941 return addExp(kCastUF, e, v);
942 if (isa<arith::ExtSIOp>(def))
943 return addExp(kCastS, e, v);
944 if (isa<arith::ExtUIOp>(def))
945 return addExp(kCastU, e, v);
946 if (isa<arith::IndexCastOp>(def))
947 return addExp(kCastIdx, e, v);
948 if (isa<arith::TruncIOp>(def))
949 return addExp(kTruncI, e, v);
950 if (isa<complex::ImOp>(def))
951 return addExp(kCIm, e);
952 if (isa<complex::ReOp>(def))
953 return addExp(kCRe, e);
954 if (isa<arith::BitcastOp>(def))
955 return addExp(kBitCast, e, v);
956 if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
957 if (isAdmissableBranch(unop, unop.getPresentRegion()) &&
958 isAdmissableBranch(unop, unop.getAbsentRegion()))
959 return addExp(kUnary, e, Value(), def);
960 }
961 }
962 }
963 // Construct binary operations if subexpressions can be built.
964 // See buildLattices() for an explanation of rejecting certain
965   // division and shift operations.
966 if (def->getNumOperands() == 2) {
967 auto x = buildTensorExp(op, def->getOperand(0));
968 auto y = buildTensorExp(op, def->getOperand(1));
969 if (x.has_value() && y.has_value()) {
970 unsigned e0 = x.value();
971 unsigned e1 = y.value();
972 if (isa<arith::MulFOp>(def))
973 return addExp(kMulF, e0, e1);
974 if (isa<complex::MulOp>(def))
975 return addExp(kMulC, e0, e1);
976 if (isa<arith::MulIOp>(def))
977 return addExp(kMulI, e0, e1);
978 if (isa<arith::DivFOp>(def) && !maybeZero(e1))
979 return addExp(kDivF, e0, e1);
980 if (isa<complex::DivOp>(def) && !maybeZero(e1))
981 return addExp(kDivC, e0, e1);
982 if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
983 return addExp(kDivS, e0, e1);
984 if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
985 return addExp(kDivU, e0, e1);
986 if (isa<arith::AddFOp>(def))
987 return addExp(kAddF, e0, e1);
988 if (isa<complex::AddOp>(def))
989 return addExp(kAddC, e0, e1);
990 if (isa<arith::AddIOp>(def))
991 return addExp(kAddI, e0, e1);
992 if (isa<arith::SubFOp>(def))
993 return addExp(kSubF, e0, e1);
994 if (isa<complex::SubOp>(def))
995 return addExp(kSubC, e0, e1);
996 if (isa<arith::SubIOp>(def))
997 return addExp(kSubI, e0, e1);
998 if (isa<arith::AndIOp>(def))
999 return addExp(kAndI, e0, e1);
1000 if (isa<arith::OrIOp>(def))
1001 return addExp(kOrI, e0, e1);
1002 if (isa<arith::XOrIOp>(def))
1003 return addExp(kXorI, e0, e1);
1004 if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
1005 return addExp(kShrS, e0, e1);
1006 if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
1007 return addExp(kShrU, e0, e1);
1008 if (isa<arith::ShLIOp>(def) && isInvariant(e1))
1009 return addExp(kShlI, e0, e1);
1010 if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1011 if (isAdmissableBranch(binop, binop.getOverlapRegion()) &&
1012 (binop.getLeftIdentity() ||
1013 isAdmissableBranch(binop, binop.getLeftRegion())) &&
1014 (binop.getRightIdentity() ||
1015 isAdmissableBranch(binop, binop.getRightRegion())))
1016 return addExp(kBinary, e0, e1, Value(), def);
1017 }
1018 }
1019 }
1020 // Cannot build.
1021 return None;
1022 }
1023
1024 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1025 ValueRange vals) {
1026 // Make a clone of overlap region.
1027 Region tmpRegion;
1028 BlockAndValueMapping mapper;
1029 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
1030 Block &clonedBlock = tmpRegion.front();
1031 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
1032 // Merge cloned block and return yield value.
1033 Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1034 rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
1035 Value val = clonedYield.getResult();
1036 rewriter.eraseOp(clonedYield);
1037 rewriter.eraseOp(placeholder);
1038 return val;
1039 }
1040
1041 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
1042 Operation *op, Value v0) {
1043 if (!v0)
1044 // Empty input value must be propagated.
1045 return Value();
1046 UnaryOp unop = cast<UnaryOp>(op);
1047 Region &presentRegion = unop.getPresentRegion();
1048 if (presentRegion.empty())
1049 // Uninitialized Value() will be interpreted as missing data in the
1050 // output.
1051 return Value();
1052 return insertYieldOp(rewriter, loc, presentRegion, {v0});
1053 }
1054
1055 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
1056 Operation *op, Value v0, Value v1) {
1057 if (!v0 || !v1)
1058 // Empty input values must be propagated.
1059 return Value();
1060 BinaryOp binop = cast<BinaryOp>(op);
1061 Region &overlapRegion = binop.getOverlapRegion();
1062 if (overlapRegion.empty())
1063 // Uninitialized Value() will be interpreted as missing data in the
1064 // output.
1065 return Value();
1066 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
1067 }
1068
1069 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
1070 Value v0, Value v1) {
1071 switch (tensorExps[e].kind) {
1072 // Leaf.
1073 case kTensor:
1074 case kInvariant:
1075 case kIndex:
1076 llvm_unreachable("unexpected non-op");
1077 // Unary operations.
1078 case kAbsF:
1079 return rewriter.create<math::AbsOp>(loc, v0);
1080 case kAbsC: {
1081 auto type = v0.getType().cast<ComplexType>();
1082 auto eltType = type.getElementType().cast<FloatType>();
1083 return rewriter.create<complex::AbsOp>(loc, eltType, v0);
1084 }
1085 case kCeilF:
1086 return rewriter.create<math::CeilOp>(loc, v0);
1087 case kFloorF:
1088 return rewriter.create<math::FloorOp>(loc, v0);
1089 case kSqrtF:
1090 return rewriter.create<math::SqrtOp>(loc, v0);
1091 case kSqrtC:
1092 return rewriter.create<complex::SqrtOp>(loc, v0);
1093 case kExpm1F:
1094 return rewriter.create<math::ExpM1Op>(loc, v0);
1095 case kExpm1C:
1096 return rewriter.create<complex::Expm1Op>(loc, v0);
1097 case kLog1pF:
1098 return rewriter.create<math::Log1pOp>(loc, v0);
1099 case kLog1pC:
1100 return rewriter.create<complex::Log1pOp>(loc, v0);
1101 case kSinF:
1102 return rewriter.create<math::SinOp>(loc, v0);
1103 case kSinC:
1104 return rewriter.create<complex::SinOp>(loc, v0);
1105 case kTanhF:
1106 return rewriter.create<math::TanhOp>(loc, v0);
1107 case kTanhC:
1108 return rewriter.create<complex::TanhOp>(loc, v0);
1109 case kNegF:
1110 return rewriter.create<arith::NegFOp>(loc, v0);
1111 case kNegC:
1112 return rewriter.create<complex::NegOp>(loc, v0);
1113 case kNegI: // no negi in std
1114 return rewriter.create<arith::SubIOp>(
1115 loc,
1116 rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1117 rewriter.getZeroAttr(v0.getType())),
1118 v0);
1119 case kTruncF:
1120 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1121 case kExtF:
1122 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1123 case kCastFS:
1124 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1125 case kCastFU:
1126 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1127 case kCastSF:
1128 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1129 case kCastUF:
1130 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1131 case kCastS:
1132 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1133 case kCastU:
1134 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1135 case kCastIdx:
1136 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1137 case kTruncI:
1138 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1139 case kCIm: {
1140 auto type = v0.getType().cast<ComplexType>();
1141 auto eltType = type.getElementType().cast<FloatType>();
1142 return rewriter.create<complex::ImOp>(loc, eltType, v0);
1143 }
1144 case kCRe: {
1145 auto type = v0.getType().cast<ComplexType>();
1146 auto eltType = type.getElementType().cast<FloatType>();
1147 return rewriter.create<complex::ReOp>(loc, eltType, v0);
1148 }
1149 case kBitCast:
1150 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1151 // Binary operations.
1152 case kMulF:
1153 return rewriter.create<arith::MulFOp>(loc, v0, v1);
1154 case kMulC:
1155 return rewriter.create<complex::MulOp>(loc, v0, v1);
1156 case kMulI:
1157 return rewriter.create<arith::MulIOp>(loc, v0, v1);
1158 case kDivF:
1159 return rewriter.create<arith::DivFOp>(loc, v0, v1);
1160 case kDivC:
1161 return rewriter.create<complex::DivOp>(loc, v0, v1);
1162 case kDivS:
1163 return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1164 case kDivU:
1165 return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1166 case kAddF:
1167 return rewriter.create<arith::AddFOp>(loc, v0, v1);
1168 case kAddC:
1169 return rewriter.create<complex::AddOp>(loc, v0, v1);
1170 case kAddI:
1171 return rewriter.create<arith::AddIOp>(loc, v0, v1);
1172 case kSubF:
1173 return rewriter.create<arith::SubFOp>(loc, v0, v1);
1174 case kSubC:
1175 return rewriter.create<complex::SubOp>(loc, v0, v1);
1176 case kSubI:
1177 return rewriter.create<arith::SubIOp>(loc, v0, v1);
1178 case kAndI:
1179 return rewriter.create<arith::AndIOp>(loc, v0, v1);
1180 case kOrI:
1181 return rewriter.create<arith::OrIOp>(loc, v0, v1);
1182 case kXorI:
1183 return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1184 case kShrS:
1185 return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1186 case kShrU:
1187 return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1188 case kShlI:
1189 return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1190 case kBinaryBranch: // semi-ring ops with custom logic.
1191 return insertYieldOp(rewriter, loc,
1192 *tensorExps[e].op->getBlock()->getParent(), {v0});
1193 case kUnary:
1194 return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
1195 case kBinary:
1196 return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
1197 }
1198 llvm_unreachable("unexpected expression kind in build");
1199 }
1200
1201 } // namespace sparse_tensor
1202 } // namespace mlir
1203