1744146f6SGus Smith //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2744146f6SGus Smith //
3744146f6SGus Smith // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4744146f6SGus Smith // See https://llvm.org/LICENSE.txt for license information.
5744146f6SGus Smith // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6744146f6SGus Smith //
7744146f6SGus Smith //===----------------------------------------------------------------------===//
8744146f6SGus Smith 
9744146f6SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11736c1b66SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
12eda6f907SRiver Riddle #include "mlir/Dialect/Math/IR/Math.h"
1390c2af57SMehdi Amini #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14744146f6SGus Smith 
15557b101cSAart Bik #include "mlir/IR/Operation.h"
16557b101cSAart Bik #include "llvm/Support/Debug.h"
17557b101cSAart Bik 
18744146f6SGus Smith namespace mlir {
19744146f6SGus Smith namespace sparse_tensor {
20744146f6SGus Smith 
21e2d3db42SAart Bik //===----------------------------------------------------------------------===//
22b8a021dbSAart Bik // Constructors.
23e2d3db42SAart Bik //===----------------------------------------------------------------------===//
24b8a021dbSAart Bik 
TensorExp(Kind k,unsigned x,unsigned y,Value v,Operation * o)252c332660SJim Kitchen TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
262c332660SJim Kitchen     : kind(k), val(v), op(o) {
27b8a021dbSAart Bik   switch (kind) {
2806aa6ec8SAart Bik   // Leaf.
298fe65972SAart Bik   case kTensor:
302c332660SJim Kitchen     assert(x != -1u && y == -1u && !v && !o);
31b8a021dbSAart Bik     tensor = x;
32b8a021dbSAart Bik     break;
338fe65972SAart Bik   case kInvariant:
342c332660SJim Kitchen     assert(x == -1u && y == -1u && v && !o);
35b8a021dbSAart Bik     break;
3653cc3a06SAart Bik   case kIndex:
372c332660SJim Kitchen     assert(x != -1u && y == -1u && !v && !o);
3853cc3a06SAart Bik     index = x;
3953cc3a06SAart Bik     break;
4006aa6ec8SAart Bik   // Unary operations.
41123e8dfcSAart Bik   case kAbsF:
42d390035bSBixia Zheng   case kAbsC:
43123e8dfcSAart Bik   case kCeilF:
44123e8dfcSAart Bik   case kFloorF:
45952fa301SAart Bik   case kSqrtF:
46a14057d4Sbixia1   case kSqrtC:
47952fa301SAart Bik   case kExpm1F:
48a14057d4Sbixia1   case kExpm1C:
49952fa301SAart Bik   case kLog1pF:
50d390035bSBixia Zheng   case kLog1pC:
51952fa301SAart Bik   case kSinF:
52d390035bSBixia Zheng   case kSinC:
53952fa301SAart Bik   case kTanhF:
54a14057d4Sbixia1   case kTanhC:
55123e8dfcSAart Bik   case kNegF:
56d390035bSBixia Zheng   case kNegC:
57123e8dfcSAart Bik   case kNegI:
5869edacbcSBixia Zheng   case kCIm:
5969edacbcSBixia Zheng   case kCRe:
602c332660SJim Kitchen     assert(x != -1u && y == -1u && !v && !o);
61123e8dfcSAart Bik     children.e0 = x;
62123e8dfcSAart Bik     children.e1 = y;
63b8a021dbSAart Bik     break;
64e2d3db42SAart Bik   case kTruncF:
65e2d3db42SAart Bik   case kExtF:
66e2d3db42SAart Bik   case kCastFS:
67e2d3db42SAart Bik   case kCastFU:
68e2d3db42SAart Bik   case kCastSF:
69e2d3db42SAart Bik   case kCastUF:
70e2d3db42SAart Bik   case kCastS:
71e2d3db42SAart Bik   case kCastU:
7253cc3a06SAart Bik   case kCastIdx:
73e2d3db42SAart Bik   case kTruncI:
74e2d3db42SAart Bik   case kBitCast:
752c332660SJim Kitchen     assert(x != -1u && y == -1u && v && !o);
762c332660SJim Kitchen     children.e0 = x;
772c332660SJim Kitchen     children.e1 = y;
782c332660SJim Kitchen     break;
792c332660SJim Kitchen   case kBinaryBranch:
802c332660SJim Kitchen     assert(x != -1u && y == -1u && !v && o);
812c332660SJim Kitchen     children.e0 = x;
822c332660SJim Kitchen     children.e1 = y;
832c332660SJim Kitchen     break;
842c332660SJim Kitchen   case kUnary:
852c332660SJim Kitchen     // No assertion on y can be made, as the branching paths involve both
862c332660SJim Kitchen     // a unary (mapSet) and binary (takeDisj) pathway.
872c332660SJim Kitchen     assert(x != -1u && !v && o);
882c332660SJim Kitchen     children.e0 = x;
892c332660SJim Kitchen     children.e1 = y;
902c332660SJim Kitchen     break;
9106aa6ec8SAart Bik   // Binary operations.
9206aa6ec8SAart Bik   case kMulF:
9306aa6ec8SAart Bik   case kMulC:
9406aa6ec8SAart Bik   case kMulI:
9506aa6ec8SAart Bik   case kDivF:
9606aa6ec8SAart Bik   case kDivC:
9706aa6ec8SAart Bik   case kDivS:
9806aa6ec8SAart Bik   case kDivU:
9906aa6ec8SAart Bik   case kAddF:
10006aa6ec8SAart Bik   case kAddC:
10106aa6ec8SAart Bik   case kAddI:
10206aa6ec8SAart Bik   case kSubF:
10306aa6ec8SAart Bik   case kSubC:
10406aa6ec8SAart Bik   case kSubI:
10506aa6ec8SAart Bik   case kAndI:
10606aa6ec8SAart Bik   case kOrI:
10706aa6ec8SAart Bik   case kXorI:
10806aa6ec8SAart Bik   case kShrS:
10906aa6ec8SAart Bik   case kShrU:
11006aa6ec8SAart Bik   case kShlI:
11106aa6ec8SAart Bik     assert(x != -1u && y != -1u && !v && !o);
112e2d3db42SAart Bik     children.e0 = x;
113e2d3db42SAart Bik     children.e1 = y;
114e2d3db42SAart Bik     break;
11506aa6ec8SAart Bik   case kBinary:
11606aa6ec8SAart Bik     assert(x != -1u && y != -1u && !v && o);
117b8a021dbSAart Bik     children.e0 = x;
118b8a021dbSAart Bik     children.e1 = y;
119b8a021dbSAart Bik     break;
120b8a021dbSAart Bik   }
121b8a021dbSAart Bik }
122b8a021dbSAart Bik 
LatPoint(unsigned n,unsigned e,unsigned b)123b8a021dbSAart Bik LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
124b8a021dbSAart Bik     : bits(n, false), simple(), exp(e) {
125b8a021dbSAart Bik   bits.set(b);
126b8a021dbSAart Bik }
127b8a021dbSAart Bik 
LatPoint(const BitVector & b,unsigned e)128d10d49dcSRiver Riddle LatPoint::LatPoint(const BitVector &b, unsigned e)
129b8a021dbSAart Bik     : bits(b), simple(), exp(e) {}
130b8a021dbSAart Bik 
131e2d3db42SAart Bik //===----------------------------------------------------------------------===//
132266a7414SAart Bik // Lattice methods.
133e2d3db42SAart Bik //===----------------------------------------------------------------------===//
134266a7414SAart Bik 
addExp(Kind k,unsigned e0,unsigned e1,Value v,Operation * op)1352c332660SJim Kitchen unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
1362c332660SJim Kitchen                         Operation *op) {
137744146f6SGus Smith   unsigned e = tensorExps.size();
1382c332660SJim Kitchen   tensorExps.push_back(TensorExp(k, e0, e1, v, op));
139744146f6SGus Smith   return e;
140744146f6SGus Smith }
141744146f6SGus Smith 
addLat(unsigned t,unsigned i,unsigned e)142744146f6SGus Smith unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
143744146f6SGus Smith   assert(t < numTensors && i < numLoops);
144744146f6SGus Smith   unsigned p = latPoints.size();
145744146f6SGus Smith   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
146744146f6SGus Smith   return p;
147744146f6SGus Smith }
148744146f6SGus Smith 
addSet()149744146f6SGus Smith unsigned Merger::addSet() {
150744146f6SGus Smith   unsigned s = latSets.size();
151744146f6SGus Smith   latSets.emplace_back(SmallVector<unsigned, 16>());
152744146f6SGus Smith   return s;
153744146f6SGus Smith }
154744146f6SGus Smith 
conjLatPoint(Kind kind,unsigned p0,unsigned p1,Operation * op)1552c332660SJim Kitchen unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
1562c332660SJim Kitchen                               Operation *op) {
157744146f6SGus Smith   unsigned p = latPoints.size();
158d10d49dcSRiver Riddle   BitVector nb = BitVector(latPoints[p0].bits);
159744146f6SGus Smith   nb |= latPoints[p1].bits;
1602c332660SJim Kitchen   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
161744146f6SGus Smith   latPoints.push_back(LatPoint(nb, e));
162744146f6SGus Smith   return p;
163744146f6SGus Smith }
164744146f6SGus Smith 
takeConj(Kind kind,unsigned s0,unsigned s1,Operation * op)1652c332660SJim Kitchen unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
166744146f6SGus Smith   unsigned s = addSet();
167744146f6SGus Smith   for (unsigned p0 : latSets[s0])
168744146f6SGus Smith     for (unsigned p1 : latSets[s1])
1692c332660SJim Kitchen       latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
170744146f6SGus Smith   return s;
171744146f6SGus Smith }
172744146f6SGus Smith 
takeDisj(Kind kind,unsigned s0,unsigned s1,Operation * op)1732c332660SJim Kitchen unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
1742c332660SJim Kitchen   unsigned s = takeConj(kind, s0, s1, op);
175123e8dfcSAart Bik   // Followed by all in s0.
176744146f6SGus Smith   for (unsigned p : latSets[s0])
177744146f6SGus Smith     latSets[s].push_back(p);
178123e8dfcSAart Bik   // Map binary 0-y to unary -y.
1792c332660SJim Kitchen   // TODO: move this if-else logic into buildLattices
1808fe65972SAart Bik   if (kind == kSubF)
1818fe65972SAart Bik     s1 = mapSet(kNegF, s1);
182d390035bSBixia Zheng   else if (kind == kSubC)
183d390035bSBixia Zheng     s1 = mapSet(kNegC, s1);
1848fe65972SAart Bik   else if (kind == kSubI)
1858fe65972SAart Bik     s1 = mapSet(kNegI, s1);
186123e8dfcSAart Bik   // Followed by all in s1.
187744146f6SGus Smith   for (unsigned p : latSets[s1])
188744146f6SGus Smith     latSets[s].push_back(p);
189744146f6SGus Smith   return s;
190744146f6SGus Smith }
191744146f6SGus Smith 
takeCombi(Kind kind,unsigned s0,unsigned s1,Operation * orig,bool includeLeft,Kind ltrans,Operation * opleft,bool includeRight,Kind rtrans,Operation * opright)1922c332660SJim Kitchen unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
1932c332660SJim Kitchen                            bool includeLeft, Kind ltrans, Operation *opleft,
1942c332660SJim Kitchen                            bool includeRight, Kind rtrans, Operation *opright) {
1952c332660SJim Kitchen   unsigned s = takeConj(kind, s0, s1, orig);
1962c332660SJim Kitchen   // Left Region.
1972c332660SJim Kitchen   if (includeLeft) {
1982c332660SJim Kitchen     if (opleft)
1992c332660SJim Kitchen       s0 = mapSet(ltrans, s0, Value(), opleft);
2002c332660SJim Kitchen     for (unsigned p : latSets[s0])
2012c332660SJim Kitchen       latSets[s].push_back(p);
2022c332660SJim Kitchen   }
2032c332660SJim Kitchen   // Right Region.
2042c332660SJim Kitchen   if (includeRight) {
2052c332660SJim Kitchen     if (opright)
2062c332660SJim Kitchen       s1 = mapSet(rtrans, s1, Value(), opright);
2072c332660SJim Kitchen     for (unsigned p : latSets[s1])
2082c332660SJim Kitchen       latSets[s].push_back(p);
2092c332660SJim Kitchen   }
2102c332660SJim Kitchen   return s;
2112c332660SJim Kitchen }
2122c332660SJim Kitchen 
mapSet(Kind kind,unsigned s0,Value v,Operation * op)2132c332660SJim Kitchen unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
2142c332660SJim Kitchen   assert(kAbsF <= kind && kind <= kUnary);
215b8a021dbSAart Bik   unsigned s = addSet();
216b8a021dbSAart Bik   for (unsigned p : latSets[s0]) {
2172c332660SJim Kitchen     unsigned e = addExp(kind, latPoints[p].exp, v, op);
218b8a021dbSAart Bik     latPoints.push_back(LatPoint(latPoints[p].bits, e));
219b8a021dbSAart Bik     latSets[s].push_back(latPoints.size() - 1);
220b8a021dbSAart Bik   }
221b8a021dbSAart Bik   return s;
222b8a021dbSAart Bik }
223b8a021dbSAart Bik 
optimizeSet(unsigned s0)224744146f6SGus Smith unsigned Merger::optimizeSet(unsigned s0) {
225744146f6SGus Smith   unsigned s = addSet();
2265a1f6077SMehdi Amini   assert(!latSets[s0].empty());
227744146f6SGus Smith   unsigned p0 = latSets[s0][0];
228744146f6SGus Smith   for (unsigned p1 : latSets[s0]) {
229744146f6SGus Smith     bool add = true;
230744146f6SGus Smith     if (p0 != p1) {
231744146f6SGus Smith       // Is this a straightforward copy?
232744146f6SGus Smith       unsigned e = latPoints[p1].exp;
2338fe65972SAart Bik       if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
234744146f6SGus Smith         continue;
235744146f6SGus Smith       // Conjunction already covered?
236744146f6SGus Smith       for (unsigned p2 : latSets[s]) {
237744146f6SGus Smith         assert(!latGT(p1, p2)); // Lj => Li would be bad
238744146f6SGus Smith         if (onlyDenseDiff(p2, p1)) {
239744146f6SGus Smith           add = false;
240744146f6SGus Smith           break;
241744146f6SGus Smith         }
242744146f6SGus Smith       }
243744146f6SGus Smith       assert(!add || latGT(p0, p1));
244744146f6SGus Smith     }
245744146f6SGus Smith     if (add)
246744146f6SGus Smith       latSets[s].push_back(p1);
247744146f6SGus Smith   }
248744146f6SGus Smith   for (unsigned p : latSets[s])
249744146f6SGus Smith     latPoints[p].simple = simplifyCond(s, p);
250744146f6SGus Smith   return s;
251744146f6SGus Smith }
252744146f6SGus Smith 
simplifyCond(unsigned s0,unsigned p0)253d10d49dcSRiver Riddle BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
254744146f6SGus Smith   // First determine if this lattice point is a *singleton*, i.e.,
255744146f6SGus Smith   // the last point in a lattice, no other is less than this one.
256744146f6SGus Smith   bool isSingleton = true;
257b8a021dbSAart Bik   for (unsigned p1 : latSets[s0]) {
258744146f6SGus Smith     if (p0 != p1 && latGT(p0, p1)) {
259744146f6SGus Smith       isSingleton = false;
260744146f6SGus Smith       break;
261744146f6SGus Smith     }
262744146f6SGus Smith   }
263744146f6SGus Smith   // Now apply the two basic rules.
264d10d49dcSRiver Riddle   BitVector simple = latPoints[p0].bits;
2658fe65972SAart Bik   bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
266744146f6SGus Smith   for (unsigned b = 0, be = simple.size(); b < be; b++) {
2678fe65972SAart Bik     if (simple[b] && !isDim(b, kSparse)) {
268744146f6SGus Smith       if (reset)
269744146f6SGus Smith         simple.reset(b);
270744146f6SGus Smith       reset = true;
271744146f6SGus Smith     }
272744146f6SGus Smith   }
273744146f6SGus Smith   return simple;
274744146f6SGus Smith }
275744146f6SGus Smith 
latGT(unsigned i,unsigned j) const276744146f6SGus Smith bool Merger::latGT(unsigned i, unsigned j) const {
277d10d49dcSRiver Riddle   const BitVector &bitsi = latPoints[i].bits;
278d10d49dcSRiver Riddle   const BitVector &bitsj = latPoints[j].bits;
279744146f6SGus Smith   assert(bitsi.size() == bitsj.size());
280744146f6SGus Smith   if (bitsi.count() > bitsj.count()) {
281744146f6SGus Smith     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
282744146f6SGus Smith       if (bitsj[b] && !bitsi[b])
283744146f6SGus Smith         return false;
284744146f6SGus Smith     return true;
285744146f6SGus Smith   }
286744146f6SGus Smith   return false;
287744146f6SGus Smith }
288744146f6SGus Smith 
onlyDenseDiff(unsigned i,unsigned j)289744146f6SGus Smith bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
290d10d49dcSRiver Riddle   BitVector tmp = latPoints[j].bits;
291744146f6SGus Smith   tmp ^= latPoints[i].bits;
2928fe65972SAart Bik   return !hasAnyDimOf(tmp, kSparse);
293744146f6SGus Smith }
294744146f6SGus Smith 
hasAnyDimOf(const BitVector & bits,Dim d) const295d10d49dcSRiver Riddle bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
296744146f6SGus Smith   for (unsigned b = 0, be = bits.size(); b < be; b++)
297744146f6SGus Smith     if (bits[b] && isDim(b, d))
298744146f6SGus Smith       return true;
299744146f6SGus Smith   return false;
300744146f6SGus Smith }
301744146f6SGus Smith 
isSingleCondition(unsigned t,unsigned e) const3020e85232fSAart Bik bool Merger::isSingleCondition(unsigned t, unsigned e) const {
30345b3cfe8SAart Bik   switch (tensorExps[e].kind) {
30406aa6ec8SAart Bik   // Leaf.
3058fe65972SAart Bik   case kTensor:
30645b3cfe8SAart Bik     return tensorExps[e].tensor == t;
30706aa6ec8SAart Bik   case kInvariant:
30806aa6ec8SAart Bik   case kIndex:
30906aa6ec8SAart Bik     return false;
31006aa6ec8SAart Bik   // Unary operations.
311123e8dfcSAart Bik   case kAbsF:
312d390035bSBixia Zheng   case kAbsC:
313123e8dfcSAart Bik   case kCeilF:
314123e8dfcSAart Bik   case kFloorF:
315952fa301SAart Bik   case kSqrtF:
316a14057d4Sbixia1   case kSqrtC:
317952fa301SAart Bik   case kExpm1F:
318a14057d4Sbixia1   case kExpm1C:
319952fa301SAart Bik   case kLog1pF:
320d390035bSBixia Zheng   case kLog1pC:
321952fa301SAart Bik   case kSinF:
322d390035bSBixia Zheng   case kSinC:
323952fa301SAart Bik   case kTanhF:
324a14057d4Sbixia1   case kTanhC:
325123e8dfcSAart Bik   case kNegF:
326d390035bSBixia Zheng   case kNegC:
327123e8dfcSAart Bik   case kNegI:
328e2d3db42SAart Bik   case kTruncF:
329e2d3db42SAart Bik   case kExtF:
330e2d3db42SAart Bik   case kCastFS:
331e2d3db42SAart Bik   case kCastFU:
332e2d3db42SAart Bik   case kCastSF:
333e2d3db42SAart Bik   case kCastUF:
334e2d3db42SAart Bik   case kCastS:
335e2d3db42SAart Bik   case kCastU:
33653cc3a06SAart Bik   case kCastIdx:
337e2d3db42SAart Bik   case kTruncI:
33869edacbcSBixia Zheng   case kCIm:
33969edacbcSBixia Zheng   case kCRe:
340e2d3db42SAart Bik   case kBitCast:
3410e85232fSAart Bik     return isSingleCondition(t, tensorExps[e].children.e0);
34206aa6ec8SAart Bik   case kBinaryBranch:
34306aa6ec8SAart Bik   case kUnary:
34406aa6ec8SAart Bik     return false;
34506aa6ec8SAart Bik   // Binary operations.
3468fe65972SAart Bik   case kDivF: // note: x / c only
347d390035bSBixia Zheng   case kDivC:
3488fe65972SAart Bik   case kDivS:
3498fe65972SAart Bik   case kDivU:
3508fe65972SAart Bik     assert(!maybeZero(tensorExps[e].children.e1));
3510e85232fSAart Bik     return isSingleCondition(t, tensorExps[e].children.e0);
3528fe65972SAart Bik   case kShrS: // note: x >> inv only
3538fe65972SAart Bik   case kShrU:
3548fe65972SAart Bik   case kShlI:
3558fe65972SAart Bik     assert(isInvariant(tensorExps[e].children.e1));
3560e85232fSAart Bik     return isSingleCondition(t, tensorExps[e].children.e0);
3578fe65972SAart Bik   case kMulF:
358736c1b66SAart Bik   case kMulC:
3598fe65972SAart Bik   case kMulI:
3608fe65972SAart Bik   case kAndI:
3610e85232fSAart Bik     if (isSingleCondition(t, tensorExps[e].children.e0))
3620e85232fSAart Bik       return isSingleCondition(t, tensorExps[e].children.e1) ||
3630e85232fSAart Bik              isInvariant(tensorExps[e].children.e1);
3640e85232fSAart Bik     if (isSingleCondition(t, tensorExps[e].children.e1))
3650e85232fSAart Bik       return isInvariant(tensorExps[e].children.e0);
3660e85232fSAart Bik     return false;
3670e85232fSAart Bik   case kAddF:
368736c1b66SAart Bik   case kAddC:
3690e85232fSAart Bik   case kAddI:
3700e85232fSAart Bik     return isSingleCondition(t, tensorExps[e].children.e0) &&
3710e85232fSAart Bik            isSingleCondition(t, tensorExps[e].children.e1);
37206aa6ec8SAart Bik   case kSubF:
37306aa6ec8SAart Bik   case kSubC:
37406aa6ec8SAart Bik   case kSubI:
37506aa6ec8SAart Bik   case kOrI:
37606aa6ec8SAart Bik   case kXorI:
37706aa6ec8SAart Bik   case kBinary:
37845b3cfe8SAart Bik     return false;
37945b3cfe8SAart Bik   }
380f8ec4dfaSMogball   llvm_unreachable("unexpected kind");
38145b3cfe8SAart Bik }
38245b3cfe8SAart Bik 
383557b101cSAart Bik #ifndef NDEBUG
384557b101cSAart Bik 
385e2d3db42SAart Bik //===----------------------------------------------------------------------===//
386557b101cSAart Bik // Print methods (for debugging).
387e2d3db42SAart Bik //===----------------------------------------------------------------------===//
388557b101cSAart Bik 
kindToOpSymbol(Kind kind)3898fe65972SAart Bik static const char *kindToOpSymbol(Kind kind) {
3908fe65972SAart Bik   switch (kind) {
39106aa6ec8SAart Bik   // Leaf.
3928fe65972SAart Bik   case kTensor:
3938fe65972SAart Bik     return "tensor";
3948fe65972SAart Bik   case kInvariant:
3958fe65972SAart Bik     return "invariant";
39653cc3a06SAart Bik   case kIndex:
39753cc3a06SAart Bik     return "index";
39806aa6ec8SAart Bik   // Unary operations.
3998fe65972SAart Bik   case kAbsF:
400d390035bSBixia Zheng   case kAbsC:
4018fe65972SAart Bik     return "abs";
4028fe65972SAart Bik   case kCeilF:
4038fe65972SAart Bik     return "ceil";
4048fe65972SAart Bik   case kFloorF:
4058fe65972SAart Bik     return "floor";
406952fa301SAart Bik   case kSqrtF:
407a14057d4Sbixia1   case kSqrtC:
408952fa301SAart Bik     return "sqrt";
409952fa301SAart Bik   case kExpm1F:
410a14057d4Sbixia1   case kExpm1C:
411952fa301SAart Bik     return "expm1";
412952fa301SAart Bik   case kLog1pF:
413d390035bSBixia Zheng   case kLog1pC:
414952fa301SAart Bik     return "log1p";
415952fa301SAart Bik   case kSinF:
416d390035bSBixia Zheng   case kSinC:
417952fa301SAart Bik     return "sin";
418952fa301SAart Bik   case kTanhF:
419a14057d4Sbixia1   case kTanhC:
420952fa301SAart Bik     return "tanh";
4218fe65972SAart Bik   case kNegF:
422d390035bSBixia Zheng   case kNegC:
4238fe65972SAart Bik   case kNegI:
4248fe65972SAart Bik     return "-";
425e2d3db42SAart Bik   case kTruncF:
426e2d3db42SAart Bik   case kExtF:
427e2d3db42SAart Bik   case kCastFS:
428e2d3db42SAart Bik   case kCastFU:
429e2d3db42SAart Bik   case kCastSF:
430e2d3db42SAart Bik   case kCastUF:
431e2d3db42SAart Bik   case kCastS:
432e2d3db42SAart Bik   case kCastU:
43353cc3a06SAart Bik   case kCastIdx:
434e2d3db42SAart Bik   case kTruncI:
43569edacbcSBixia Zheng   case kCIm:
43669edacbcSBixia Zheng     return "complex.im";
43769edacbcSBixia Zheng   case kCRe:
43869edacbcSBixia Zheng     return "complex.re";
439e2d3db42SAart Bik   case kBitCast:
440e2d3db42SAart Bik     return "cast";
4412c332660SJim Kitchen   case kBinaryBranch:
4422c332660SJim Kitchen     return "binary_branch";
4432c332660SJim Kitchen   case kUnary:
4442c332660SJim Kitchen     return "unary";
44506aa6ec8SAart Bik   // Binary operations.
4468fe65972SAart Bik   case kMulF:
447736c1b66SAart Bik   case kMulC:
4488fe65972SAart Bik   case kMulI:
4498fe65972SAart Bik     return "*";
4508fe65972SAart Bik   case kDivF:
451d390035bSBixia Zheng   case kDivC:
4528fe65972SAart Bik   case kDivS:
4538fe65972SAart Bik   case kDivU:
4548fe65972SAart Bik     return "/";
4558fe65972SAart Bik   case kAddF:
456736c1b66SAart Bik   case kAddC:
4578fe65972SAart Bik   case kAddI:
4588fe65972SAart Bik     return "+";
4598fe65972SAart Bik   case kSubF:
460d390035bSBixia Zheng   case kSubC:
4618fe65972SAart Bik   case kSubI:
4628fe65972SAart Bik     return "-";
4638fe65972SAart Bik   case kAndI:
4648fe65972SAart Bik     return "&";
4658fe65972SAart Bik   case kOrI:
4668fe65972SAart Bik     return "|";
4678fe65972SAart Bik   case kXorI:
4688fe65972SAart Bik     return "^";
4698fe65972SAart Bik   case kShrS:
4708fe65972SAart Bik     return "a>>";
4718fe65972SAart Bik   case kShrU:
4728fe65972SAart Bik     return ">>";
4738fe65972SAart Bik   case kShlI:
4748fe65972SAart Bik     return "<<";
4752c332660SJim Kitchen   case kBinary:
4762c332660SJim Kitchen     return "binary";
4778fe65972SAart Bik   }
4788fe65972SAart Bik   llvm_unreachable("unexpected kind for symbol");
4798fe65972SAart Bik }
480b8a021dbSAart Bik 
dumpExp(unsigned e) const481557b101cSAart Bik void Merger::dumpExp(unsigned e) const {
482557b101cSAart Bik   switch (tensorExps[e].kind) {
48306aa6ec8SAart Bik   // Leaf.
4848fe65972SAart Bik   case kTensor:
4854569c14aSGus Smith     if (tensorExps[e].tensor == syntheticTensor)
486266a7414SAart Bik       llvm::dbgs() << "synthetic_";
4874569c14aSGus Smith     else if (tensorExps[e].tensor == outTensor)
488266a7414SAart Bik       llvm::dbgs() << "output_";
4894569c14aSGus Smith     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
490557b101cSAart Bik     break;
4918fe65972SAart Bik   case kInvariant:
492557b101cSAart Bik     llvm::dbgs() << "invariant";
493557b101cSAart Bik     break;
49453cc3a06SAart Bik   case kIndex:
49553cc3a06SAart Bik     llvm::dbgs() << "index_" << tensorExps[e].index;
49653cc3a06SAart Bik     break;
49706aa6ec8SAart Bik   // Unary operations.
498123e8dfcSAart Bik   case kAbsF:
49906aa6ec8SAart Bik   case kAbsC:
500123e8dfcSAart Bik   case kCeilF:
501123e8dfcSAart Bik   case kFloorF:
502952fa301SAart Bik   case kSqrtF:
503a14057d4Sbixia1   case kSqrtC:
504952fa301SAart Bik   case kExpm1F:
505a14057d4Sbixia1   case kExpm1C:
506952fa301SAart Bik   case kLog1pF:
50706aa6ec8SAart Bik   case kLog1pC:
508952fa301SAart Bik   case kSinF:
50906aa6ec8SAart Bik   case kSinC:
510952fa301SAart Bik   case kTanhF:
511a14057d4Sbixia1   case kTanhC:
512123e8dfcSAart Bik   case kNegF:
51306aa6ec8SAart Bik   case kNegC:
514123e8dfcSAart Bik   case kNegI:
515e2d3db42SAart Bik   case kTruncF:
516e2d3db42SAart Bik   case kExtF:
517e2d3db42SAart Bik   case kCastFS:
518e2d3db42SAart Bik   case kCastFU:
519e2d3db42SAart Bik   case kCastSF:
520e2d3db42SAart Bik   case kCastUF:
521e2d3db42SAart Bik   case kCastS:
522e2d3db42SAart Bik   case kCastU:
52353cc3a06SAart Bik   case kCastIdx:
524e2d3db42SAart Bik   case kTruncI:
52506aa6ec8SAart Bik   case kCIm:
52606aa6ec8SAart Bik   case kCRe:
527e2d3db42SAart Bik   case kBitCast:
52806aa6ec8SAart Bik   case kBinaryBranch:
52906aa6ec8SAart Bik   case kUnary:
5308fe65972SAart Bik     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
531123e8dfcSAart Bik     dumpExp(tensorExps[e].children.e0);
532b8a021dbSAart Bik     break;
53306aa6ec8SAart Bik   // Binary operations.
53406aa6ec8SAart Bik   case kMulF:
53506aa6ec8SAart Bik   case kMulC:
53606aa6ec8SAart Bik   case kMulI:
53706aa6ec8SAart Bik   case kDivF:
53806aa6ec8SAart Bik   case kDivC:
53906aa6ec8SAart Bik   case kDivS:
54006aa6ec8SAart Bik   case kDivU:
54106aa6ec8SAart Bik   case kAddF:
54206aa6ec8SAart Bik   case kAddC:
54306aa6ec8SAart Bik   case kAddI:
54406aa6ec8SAart Bik   case kSubF:
54506aa6ec8SAart Bik   case kSubC:
54606aa6ec8SAart Bik   case kSubI:
54706aa6ec8SAart Bik   case kAndI:
54806aa6ec8SAart Bik   case kOrI:
54906aa6ec8SAart Bik   case kXorI:
55006aa6ec8SAart Bik   case kShrS:
55106aa6ec8SAart Bik   case kShrU:
55206aa6ec8SAart Bik   case kShlI:
55306aa6ec8SAart Bik   case kBinary:
554557b101cSAart Bik     llvm::dbgs() << "(";
5554569c14aSGus Smith     dumpExp(tensorExps[e].children.e0);
5568fe65972SAart Bik     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
5574569c14aSGus Smith     dumpExp(tensorExps[e].children.e1);
558557b101cSAart Bik     llvm::dbgs() << ")";
559557b101cSAart Bik   }
560557b101cSAart Bik }
561557b101cSAart Bik 
dumpLat(unsigned p) const562557b101cSAart Bik void Merger::dumpLat(unsigned p) const {
563557b101cSAart Bik   llvm::dbgs() << "lat(";
564557b101cSAart Bik   dumpBits(latPoints[p].bits);
565557b101cSAart Bik   llvm::dbgs() << " :";
566557b101cSAart Bik   dumpBits(latPoints[p].simple);
567b8a021dbSAart Bik   llvm::dbgs() << " : ";
568557b101cSAart Bik   dumpExp(latPoints[p].exp);
569557b101cSAart Bik   llvm::dbgs() << " )\n";
570557b101cSAart Bik }
571557b101cSAart Bik 
dumpSet(unsigned s) const572557b101cSAart Bik void Merger::dumpSet(unsigned s) const {
573557b101cSAart Bik   llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
574557b101cSAart Bik   for (unsigned p : latSets[s]) {
575557b101cSAart Bik     llvm::dbgs() << "  ";
576557b101cSAart Bik     dumpLat(p);
577557b101cSAart Bik   }
578557b101cSAart Bik   llvm::dbgs() << "}\n";
579557b101cSAart Bik }
580557b101cSAart Bik 
dumpBits(const BitVector & bits) const581d10d49dcSRiver Riddle void Merger::dumpBits(const BitVector &bits) const {
582557b101cSAart Bik   for (unsigned b = 0, be = bits.size(); b < be; b++) {
583557b101cSAart Bik     if (bits[b]) {
584557b101cSAart Bik       unsigned t = tensor(b);
585557b101cSAart Bik       unsigned i = index(b);
586557b101cSAart Bik       llvm::dbgs() << " i_" << t << "_" << i << "_";
587557b101cSAart Bik       switch (dims[t][i]) {
5888fe65972SAart Bik       case kSparse:
589557b101cSAart Bik         llvm::dbgs() << "S";
590557b101cSAart Bik         break;
5918fe65972SAart Bik       case kDense:
592557b101cSAart Bik         llvm::dbgs() << "D";
593557b101cSAart Bik         break;
5948fe65972SAart Bik       case kSingle:
595557b101cSAart Bik         llvm::dbgs() << "T";
596557b101cSAart Bik         break;
5978fe65972SAart Bik       case kUndef:
598557b101cSAart Bik         llvm::dbgs() << "U";
599557b101cSAart Bik         break;
600557b101cSAart Bik       }
601557b101cSAart Bik     }
602557b101cSAart Bik   }
603557b101cSAart Bik }
604557b101cSAart Bik 
605557b101cSAart Bik #endif // NDEBUG
606557b101cSAart Bik 
607e2d3db42SAart Bik //===----------------------------------------------------------------------===//
608266a7414SAart Bik // Builder methods.
609e2d3db42SAart Bik //===----------------------------------------------------------------------===//
610266a7414SAart Bik 
buildLattices(unsigned e,unsigned i)61145b3cfe8SAart Bik unsigned Merger::buildLattices(unsigned e, unsigned i) {
612266a7414SAart Bik   Kind kind = tensorExps[e].kind;
613b8a021dbSAart Bik   switch (kind) {
61406aa6ec8SAart Bik   // Leaf.
6158fe65972SAart Bik   case kTensor:
61653cc3a06SAart Bik   case kInvariant:
61753cc3a06SAart Bik   case kIndex: {
618266a7414SAart Bik     // Either the index is really used in the tensor expression, or it is
61953cc3a06SAart Bik     // set to the undefined index in that dimension. An invariant expression,
62053cc3a06SAart Bik     // a proper index value, and a truly dynamic sparse output tensor are set
62153cc3a06SAart Bik     // to a synthetic tensor with undefined indices only to ensure the
62253cc3a06SAart Bik     // iteration space is not skipped as a result of their contents.
623266a7414SAart Bik     unsigned s = addSet();
62453cc3a06SAart Bik     unsigned t = syntheticTensor;
62553cc3a06SAart Bik     if (kind == kTensor) {
62653cc3a06SAart Bik       t = tensorExps[e].tensor;
6277d4da4e1SAart Bik       if (hasSparseOut && t == outTensor)
6287d4da4e1SAart Bik         t = syntheticTensor;
62953cc3a06SAart Bik     }
63045b3cfe8SAart Bik     latSets[s].push_back(addLat(t, i, e));
631266a7414SAart Bik     return s;
632266a7414SAart Bik   }
63306aa6ec8SAart Bik   // Unary operations.
634123e8dfcSAart Bik   case kAbsF:
635d390035bSBixia Zheng   case kAbsC:
636123e8dfcSAart Bik   case kCeilF:
637123e8dfcSAart Bik   case kFloorF:
638952fa301SAart Bik   case kSqrtF:
639a14057d4Sbixia1   case kSqrtC:
640952fa301SAart Bik   case kExpm1F:
641a14057d4Sbixia1   case kExpm1C:
642952fa301SAart Bik   case kLog1pF:
643d390035bSBixia Zheng   case kLog1pC:
644952fa301SAart Bik   case kSinF:
645d390035bSBixia Zheng   case kSinC:
646952fa301SAart Bik   case kTanhF:
647a14057d4Sbixia1   case kTanhC:
648123e8dfcSAart Bik   case kNegF:
649d390035bSBixia Zheng   case kNegC:
650123e8dfcSAart Bik   case kNegI:
651e2d3db42SAart Bik   case kTruncF:
652e2d3db42SAart Bik   case kExtF:
653e2d3db42SAart Bik   case kCastFS:
654e2d3db42SAart Bik   case kCastFU:
655e2d3db42SAart Bik   case kCastSF:
656e2d3db42SAart Bik   case kCastUF:
657e2d3db42SAart Bik   case kCastS:
658e2d3db42SAart Bik   case kCastU:
65953cc3a06SAart Bik   case kCastIdx:
660e2d3db42SAart Bik   case kTruncI:
66106aa6ec8SAart Bik   case kCIm:
66206aa6ec8SAart Bik   case kCRe:
663e2d3db42SAart Bik   case kBitCast:
664123e8dfcSAart Bik     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
665123e8dfcSAart Bik     // lattice set of the operand through the operator into a new set.
666123e8dfcSAart Bik     //
667123e8dfcSAart Bik     //  -y|!y | y |
668123e8dfcSAart Bik     //  --+---+---+
669123e8dfcSAart Bik     //    | 0 |-y |
670e2d3db42SAart Bik     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
671e2d3db42SAart Bik                   tensorExps[e].val);
6722c332660SJim Kitchen   case kBinaryBranch:
6732c332660SJim Kitchen     // The left or right half of a binary operation which has already
6742c332660SJim Kitchen     // been split into separate operations for each region.
6752c332660SJim Kitchen     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
6762c332660SJim Kitchen                   tensorExps[e].op);
6772c332660SJim Kitchen   case kUnary:
6782c332660SJim Kitchen     // A custom unary operation.
6792c332660SJim Kitchen     //
6802c332660SJim Kitchen     //  op y|    !y    |     y      |
6812c332660SJim Kitchen     //  ----+----------+------------+
6822c332660SJim Kitchen     //      | absent() | present(y) |
6832c332660SJim Kitchen     {
6842c332660SJim Kitchen       unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
6852c332660SJim Kitchen       UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
68604235d07SJacques Pienaar       Region &absentRegion = unop.getAbsentRegion();
6872c332660SJim Kitchen 
6882c332660SJim Kitchen       if (absentRegion.empty()) {
6892c332660SJim Kitchen         // Simple mapping over existing values.
6902c332660SJim Kitchen         return mapSet(kind, child0, Value(), unop);
691c5ea8d50SMehdi Amini       } // Use a disjunction with `unop` on the left and the absent value as an
6922c332660SJim Kitchen       // invariant on the right.
6932c332660SJim Kitchen       Block &absentBlock = absentRegion.front();
6942c332660SJim Kitchen       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
69504235d07SJacques Pienaar       Value absentVal = absentYield.getResult();
6962c332660SJim Kitchen       unsigned rhs = addExp(kInvariant, absentVal);
6972c332660SJim Kitchen       return takeDisj(kind, child0, buildLattices(rhs, i), unop);
6982c332660SJim Kitchen     }
69906aa6ec8SAart Bik   // Binary operations.
7008fe65972SAart Bik   case kMulF:
701736c1b66SAart Bik   case kMulC:
7028fe65972SAart Bik   case kMulI:
7038fe65972SAart Bik   case kAndI:
704622eb169SAart Bik     // A multiplicative operation only needs to be performed
705622eb169SAart Bik     // for the conjunction of sparse iteration spaces.
706622eb169SAart Bik     //
707622eb169SAart Bik     //  x*y|!y | y |
708622eb169SAart Bik     //  ---+---+---+
709622eb169SAart Bik     //  !x | 0 | 0 |
710622eb169SAart Bik     //   x | 0 |x*y|
711736c1b66SAart Bik     //
712736c1b66SAart Bik     // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
713622eb169SAart Bik     return takeConj(kind, // take binary conjunction
71445b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e0, i),
71545b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e1, i));
7168fe65972SAart Bik   case kDivF:
717d390035bSBixia Zheng   case kDivC:
7188fe65972SAart Bik   case kDivS:
7198fe65972SAart Bik   case kDivU:
720622eb169SAart Bik     // A division is tricky, since 0/0, 0/c, c/0 all have
721622eb169SAart Bik     // specific outcomes for floating-point and integers.
722622eb169SAart Bik     // Thus, we need to traverse the full iteration space.
723622eb169SAart Bik     //
724622eb169SAart Bik     //  x/y|!y | y |
725622eb169SAart Bik     //  ---+---+---+
726622eb169SAart Bik     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
727622eb169SAart Bik     //   x |x/0|x/y|  INT: x/0=exception for any x
728622eb169SAart Bik     //
729622eb169SAart Bik     // TODO: for now we "fixed" this by only accepting x/c cases
730622eb169SAart Bik     //       during expression building, so that the conjunction
731622eb169SAart Bik     //       rules applies (viz. x/c = x*(1/c) as far as lattice
732622eb169SAart Bik     //       construction is concerned).
7338fe65972SAart Bik     assert(!maybeZero(tensorExps[e].children.e1));
734b8a021dbSAart Bik     return takeConj(kind, // take binary conjunction
73545b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e0, i),
73645b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e1, i));
7378fe65972SAart Bik   case kAddF:
738736c1b66SAart Bik   case kAddC:
7398fe65972SAart Bik   case kAddI:
7408fe65972SAart Bik   case kSubF:
741d390035bSBixia Zheng   case kSubC:
7428fe65972SAart Bik   case kSubI:
7438fe65972SAart Bik   case kOrI:
7448fe65972SAart Bik   case kXorI:
745622eb169SAart Bik     // An additive operation needs to be performed
746622eb169SAart Bik     // for the disjunction of sparse iteration spaces.
747622eb169SAart Bik     //
748622eb169SAart Bik     //  x+y|!y | y |    x-y|!y | y |
749622eb169SAart Bik     //  ---+---+---+    ---+---+---+
750622eb169SAart Bik     //  !x | 0 | y |    !x | 0 |-y |
751622eb169SAart Bik     //   x | x |x+y|     x | x |x-y|
752b8a021dbSAart Bik     return takeDisj(kind, // take binary disjunction
75345b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e0, i),
75445b3cfe8SAart Bik                     buildLattices(tensorExps[e].children.e1, i));
7558fe65972SAart Bik   case kShrS:
7568fe65972SAart Bik   case kShrU:
7578fe65972SAart Bik   case kShlI:
7582b6e4332SAart Bik     // A shift operation by an invariant amount (viz. tensor expressions
7592b6e4332SAart Bik     // can only occur at the left-hand-side of the operator) can be handled
7602b6e4332SAart Bik     // with the conjuction rule.
7618fe65972SAart Bik     assert(isInvariant(tensorExps[e].children.e1));
7622b6e4332SAart Bik     return takeConj(kind, // take binary conjunction
7632b6e4332SAart Bik                     buildLattices(tensorExps[e].children.e0, i),
7642b6e4332SAart Bik                     buildLattices(tensorExps[e].children.e1, i));
7652c332660SJim Kitchen   case kBinary:
7662c332660SJim Kitchen     // A custom binary operation.
7672c332660SJim Kitchen     //
7682c332660SJim Kitchen     //  x op y|   !y    |       y      |
7692c332660SJim Kitchen     //  ------+---------+--------------+
7702c332660SJim Kitchen     //    !x  |  empty  |   right(y)   |
7712c332660SJim Kitchen     //     x  | left(x) | overlap(x,y) |
7722c332660SJim Kitchen     {
7732c332660SJim Kitchen       unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
7742c332660SJim Kitchen       unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
7752c332660SJim Kitchen       BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
77604235d07SJacques Pienaar       Region &leftRegion = binop.getLeftRegion();
77704235d07SJacques Pienaar       Region &rightRegion = binop.getRightRegion();
7782c332660SJim Kitchen       // Left Region.
7792c332660SJim Kitchen       Operation *leftYield = nullptr;
7802c332660SJim Kitchen       if (!leftRegion.empty()) {
7812c332660SJim Kitchen         Block &leftBlock = leftRegion.front();
7822c332660SJim Kitchen         leftYield = leftBlock.getTerminator();
7832c332660SJim Kitchen       }
7842c332660SJim Kitchen       // Right Region.
7852c332660SJim Kitchen       Operation *rightYield = nullptr;
7862c332660SJim Kitchen       if (!rightRegion.empty()) {
7872c332660SJim Kitchen         Block &rightBlock = rightRegion.front();
7882c332660SJim Kitchen         rightYield = rightBlock.getTerminator();
7892c332660SJim Kitchen       }
79004235d07SJacques Pienaar       bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
79104235d07SJacques Pienaar       bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
7922c332660SJim Kitchen       return takeCombi(kBinary, child0, child1, binop, includeLeft,
7932c332660SJim Kitchen                        kBinaryBranch, leftYield, includeRight, kBinaryBranch,
7942c332660SJim Kitchen                        rightYield);
7952c332660SJim Kitchen     }
796266a7414SAart Bik   }
797266a7414SAart Bik   llvm_unreachable("unexpected expression kind");
798266a7414SAart Bik }
799266a7414SAart Bik 
buildTensorExpFromLinalg(linalg::GenericOp op)800266a7414SAart Bik Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
8012a288616SAart Bik   // Build the linalg semantics backward from yield.
802266a7414SAart Bik   Operation *yield = op.region().front().getTerminator();
8032a288616SAart Bik   assert(isa<linalg::YieldOp>(yield));
804266a7414SAart Bik   return buildTensorExp(op, yield->getOperand(0));
805266a7414SAart Bik }
806266a7414SAart Bik 
80746e77b5dSAart Bik /// Only returns false if we are certain this is a nonzero.
maybeZero(unsigned e) const8082b6e4332SAart Bik bool Merger::maybeZero(unsigned e) const {
8098fe65972SAart Bik   if (tensorExps[e].kind == kInvariant) {
810d390035bSBixia Zheng     if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
811d390035bSBixia Zheng       ArrayAttr arrayAttr = c.getValue();
812d390035bSBixia Zheng       return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
813d390035bSBixia Zheng              arrayAttr[0].cast<FloatAttr>().getValue().isZero();
814d390035bSBixia Zheng     }
815a54f4eaeSMogball     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
816a54f4eaeSMogball       return c.value() == 0;
817a54f4eaeSMogball     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
818a54f4eaeSMogball       return c.value().isZero();
819622eb169SAart Bik   }
820622eb169SAart Bik   return true;
821622eb169SAart Bik }
822622eb169SAart Bik 
isInvariant(unsigned e) const8232b6e4332SAart Bik bool Merger::isInvariant(unsigned e) const {
8248fe65972SAart Bik   return tensorExps[e].kind == kInvariant;
8252b6e4332SAart Bik }
8262b6e4332SAart Bik 
inferType(unsigned e,Value src)827e2d3db42SAart Bik Type Merger::inferType(unsigned e, Value src) {
828e2d3db42SAart Bik   // Obtain the destination type from the cast node.
829e2d3db42SAart Bik   Type dtp = tensorExps[e].val.getType();
830e2d3db42SAart Bik   // Inspect source type. For vector types, apply the same
831e2d3db42SAart Bik   // vectorization to the destination type.
832e2d3db42SAart Bik   if (auto vtp = src.getType().dyn_cast<VectorType>())
8337783a178SJavier Setoain     return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
834e2d3db42SAart Bik   return dtp;
835e2d3db42SAart Bik }
836e2d3db42SAart Bik 
8372a288616SAart Bik /// Ensures that sparse compiler can generate code for expression.
isAdmissableBranchExp(Operation * op,Block * block,Value v)8382a288616SAart Bik static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
8392a288616SAart Bik   // Arguments are always admissable.
8402a288616SAart Bik   if (auto arg = v.dyn_cast<BlockArgument>())
8412a288616SAart Bik     return true;
8422a288616SAart Bik   // Accept index anywhere.
8432a288616SAart Bik   Operation *def = v.getDefiningOp();
8442a288616SAart Bik   if (isa<linalg::IndexOp>(def))
8452a288616SAart Bik     return true;
8462a288616SAart Bik   // Operation defined outside branch.
8472a288616SAart Bik   if (def->getBlock() != block) {
8482a288616SAart Bik     return def->getBlock() != op->getBlock(); // invariant?
8492a288616SAart Bik   }
8502a288616SAart Bik   // Operation defined within branch. Anything is accepted,
8512a288616SAart Bik   // as long as all subexpressions are admissable.
8522a288616SAart Bik   for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
8532a288616SAart Bik     if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
8542a288616SAart Bik       return false;
8552a288616SAart Bik   return true;
8562a288616SAart Bik }
8572a288616SAart Bik 
8582a288616SAart Bik /// Ensures that sparse compiler can generate code for branch.
isAdmissableBranch(Operation * op,Region & region)8592a288616SAart Bik static bool isAdmissableBranch(Operation *op, Region &region) {
8602a288616SAart Bik   if (region.empty())
8612a288616SAart Bik     return true;
8622a288616SAart Bik   // Build the semi-ring branch semantics backward from yield.
8632a288616SAart Bik   Operation *yield = region.front().getTerminator();
8642a288616SAart Bik   assert(isa<YieldOp>(yield));
8652a288616SAart Bik   return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
8662a288616SAart Bik }
8672a288616SAart Bik 
/// Recursively builds a tensor expression for the value `v` computed in the
/// body of the linalg.generic `op`.
///
/// Returns the index of the new expression, or None when some part of the
/// computation is not supported. Unsupported parts include an unrecognized
/// operation, a division whose divisor may be zero, a shift by a
/// non-invariant amount, or a semi-ring op with a non-admissable branch.
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      unsigned e = x.value();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      // Cast-like operations also pass their result `v`. inferType() later
      // reads the destination type from the value recorded in the node.
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      // A custom semi-ring unary op is only accepted when code can be
      // generated for both its present and its absent branch.
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissableBranch(unop, unop.getPresentRegion()) &&
            isAdmissableBranch(unop, unop.getAbsentRegion()))
          return addExp(kUnary, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      unsigned e0 = x.value();
      unsigned e1 = y.value();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      // Divisions are only accepted when the divisor is known to be nonzero.
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      // Shifts are only accepted when the shift amount is invariant.
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      // A custom semi-ring binary op needs an admissable overlap branch.
      // The left/right branches are checked unless the corresponding
      // identity flag is set.
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissableBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissableBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissableBranch(binop, binop.getRightRegion())))
          return addExp(kBinary, e0, e1, Value(), def);
      }
    }
  }
  // Cannot build.
  return None;
}
1023266a7414SAart Bik 
insertYieldOp(RewriterBase & rewriter,Location loc,Region & region,ValueRange vals)1024e9fa5590SMatthias Springer static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
1025e9fa5590SMatthias Springer                            ValueRange vals) {
10262c332660SJim Kitchen   // Make a clone of overlap region.
10272c332660SJim Kitchen   Region tmpRegion;
10282c332660SJim Kitchen   BlockAndValueMapping mapper;
10292c332660SJim Kitchen   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
10302c332660SJim Kitchen   Block &clonedBlock = tmpRegion.front();
10312c332660SJim Kitchen   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
10322c332660SJim Kitchen   // Merge cloned block and return yield value.
10332c332660SJim Kitchen   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
10342c332660SJim Kitchen   rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
103504235d07SJacques Pienaar   Value val = clonedYield.getResult();
10362c332660SJim Kitchen   rewriter.eraseOp(clonedYield);
10372c332660SJim Kitchen   rewriter.eraseOp(placeholder);
10382c332660SJim Kitchen   return val;
10392c332660SJim Kitchen }
10402c332660SJim Kitchen 
buildUnaryPresent(RewriterBase & rewriter,Location loc,Operation * op,Value v0)1041e9fa5590SMatthias Springer static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
10422c332660SJim Kitchen                                Operation *op, Value v0) {
10432c332660SJim Kitchen   if (!v0)
10442c332660SJim Kitchen     // Empty input value must be propagated.
10452c332660SJim Kitchen     return Value();
10462c332660SJim Kitchen   UnaryOp unop = cast<UnaryOp>(op);
104704235d07SJacques Pienaar   Region &presentRegion = unop.getPresentRegion();
10482c332660SJim Kitchen   if (presentRegion.empty())
10492c332660SJim Kitchen     // Uninitialized Value() will be interpreted as missing data in the
10502c332660SJim Kitchen     // output.
10512c332660SJim Kitchen     return Value();
10522c332660SJim Kitchen   return insertYieldOp(rewriter, loc, presentRegion, {v0});
10532c332660SJim Kitchen }
10542c332660SJim Kitchen 
buildBinaryOverlap(RewriterBase & rewriter,Location loc,Operation * op,Value v0,Value v1)1055e9fa5590SMatthias Springer static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
10562c332660SJim Kitchen                                 Operation *op, Value v0, Value v1) {
10572c332660SJim Kitchen   if (!v0 || !v1)
10582c332660SJim Kitchen     // Empty input values must be propagated.
10592c332660SJim Kitchen     return Value();
10602c332660SJim Kitchen   BinaryOp binop = cast<BinaryOp>(op);
106104235d07SJacques Pienaar   Region &overlapRegion = binop.getOverlapRegion();
10622c332660SJim Kitchen   if (overlapRegion.empty())
10632c332660SJim Kitchen     // Uninitialized Value() will be interpreted as missing data in the
10642c332660SJim Kitchen     // output.
10652c332660SJim Kitchen     return Value();
10662c332660SJim Kitchen   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
10672c332660SJim Kitchen }
10682c332660SJim Kitchen 
/// Emits the IR for the non-leaf expression `e`, given the already built
/// operand values `v0` (and `v1` for binary kinds), and returns the result.
///
/// Leaf kinds (tensor, invariant, index) must not reach this function.
/// For kUnary and kBinary the result may be an empty Value(), which denotes
/// missing data in the output.
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    // The absolute value of a complex number has the floating-point element
    // type of the complex type.
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    // Integer negation is emitted as 0 - v0.
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  // Cast-like operations: the result type comes from inferType().
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm: {
    // Extracting the imaginary part yields the complex type's element type.
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case kCRe: {
    // Extracting the real part yields the complex type's element type.
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case kBinaryBranch: // semi-ring ops with custom logic.
    // The stored op is the terminator of a left/right branch. The region
    // that contains it is inlined with `v0` as its argument.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}
120045b3cfe8SAart Bik 
1201744146f6SGus Smith } // namespace sparse_tensor
1202744146f6SGus Smith } // namespace mlir
1203