//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with iteration lattices.
//
//===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
15 
16 #include "mlir/Dialect/Linalg/IR/Linalg.h"
17 #include "mlir/IR/Value.h"
18 #include "llvm/ADT/BitVector.h"
19 
20 namespace mlir {
21 namespace sparse_tensor {
22 
/// Dimension level type for a tensor (undef means index does not appear).
/// The enumerators are unscoped and implicitly numbered in declaration
/// order, so reordering them changes their integer values.
enum Dim {
  kSparse,
  kDense,
  kSingle,
  kUndef // index does not appear
};
25 
/// Tensor expression kind.
///
/// The enumerators are laid out in three contiguous groups: leaves, unary
/// operations, and binary operations. Only `kTensor` has an explicit value;
/// all others are numbered implicitly in declaration order, so inserting or
/// reordering enumerators changes the values of everything that follows.
enum Kind {
  // Leaf.
  kTensor = 0,
  kInvariant,
  kIndex,
  // Unary operations.
  kAbsF,
  kAbsC,
  kCeilF,
  kFloorF,
  kSqrtF,
  kSqrtC,
  kExpm1F,
  kExpm1C,
  kLog1pF,
  kLog1pC,
  kSinF,
  kSinC,
  kTanhF,
  kTanhC,
  kNegF,
  kNegC,
  kNegI,
  kTruncF,
  kExtF,
  kCastFS, // signed
  kCastFU, // unsigned
  kCastSF, // signed
  kCastUF, // unsigned
  kCastS,  // signed
  kCastU,  // unsigned
  kCastIdx,
  kTruncI,
  kCIm, // complex.im
  kCRe, // complex.re
  kBitCast,
  kBinaryBranch, // semiring unary branch created from a binary op
  kUnary,        // semiring unary op
  // Binary operations.
  kMulF,
  kMulC,
  kMulI,
  kDivF,
  kDivC, // complex
  kDivS, // signed
  kDivU, // unsigned
  kAddF,
  kAddC,
  kAddI,
  kSubF,
  kSubC,
  kSubI,
  kAndI,
  kOrI,
  kXorI,
  kShrS, // signed
  kShrU, // unsigned
  kShlI,
  kBinary, // semiring binary op
};
87 
/// Children subexpressions of tensor operations.
struct Children {
  /// Index of the first child subexpression.
  unsigned e0;
  /// Index of the second child subexpression.
  unsigned e1;
};
93 
94 /// Tensor expression. Represents a MLIR expression in tensor index notation.
95 struct TensorExp {
96   TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation);
97 
98   /// Tensor expression kind.
99   Kind kind;
100 
101   union {
102     /// Expressions representing tensors simply have a tensor number.
103     unsigned tensor;
104 
105     /// Indices hold the index number.
106     unsigned index;
107 
108     /// Tensor operations hold the indices of their children.
109     Children children;
110   };
111 
112   /// Direct link to IR for an invariant or the destination value (to
113   /// infer destination type) of a cast operation During code generation,
114   /// this field may be used to cache "hoisted" loop invariant tensor loads.
115   Value val;
116 
117   /// Code blocks used by semirings. For the case of kUnary and
118   /// kBinary, this holds the original operation with all regions. For
119   /// kBinaryBranch, this holds the YieldOp for the left or right half
120   /// to be merged into a nested scf loop.
121   Operation *op;
122 };
123 
124 /// Lattice point. Each lattice point consists of a conjunction of tensor
125 /// loop indices (encoded in a bitvector) and the index of the corresponding
126 /// tensor expression.
127 struct LatPoint {
128   LatPoint(unsigned n, unsigned e, unsigned b);
129   LatPoint(const BitVector &b, unsigned e);
130 
131   /// Conjunction of tensor loop indices as bitvector. This represents
132   /// all indices involved in the tensor expression
133   BitVector bits;
134 
135   /// Simplified conjunction of tensor loop indices as bitvector. This
136   /// represents a simplified condition under which this tensor expression
137   /// must execute. Pre-computed during codegen to avoid repeated eval.
138   BitVector simple;
139 
140   /// Index of the tensor expression.
141   unsigned exp;
142 };
143 
144 /// A class to handle all iteration lattice operations. This class abstracts
145 /// away from some implementation details of storing iteration lattices and
146 /// tensor expressions. This allows for fine-tuning performance characteristics
147 /// independently from the basic algorithm if bottlenecks are identified.
148 class Merger {
149 public:
150   /// Constructs a merger for the given number of tensors and loops. The
151   /// user supplies the number of tensors involved in the kernel, with the
152   /// last tensor in this set denoting the output tensor. The merger adds an
153   /// additional synthetic tensor at the end of this set to represent all
154   /// invariant expressions in the kernel.
Merger(unsigned t,unsigned l)155   Merger(unsigned t, unsigned l)
156       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
157         hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
158 
159   /// Adds a tensor expression. Returns its index.
160   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(),
161                   Operation *op = nullptr);
162   unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) {
163     return addExp(k, e, -1u, v, op);
164   }
165   unsigned addExp(Kind k, Value v, Operation *op = nullptr) {
166     return addExp(k, -1u, -1u, v, op);
167   }
168 
169   /// Adds an iteration lattice point. Returns its index.
170   unsigned addLat(unsigned t, unsigned i, unsigned e);
171 
172   /// Adds a new, initially empty, set. Returns its index.
173   unsigned addSet();
174 
175   /// Computes a single conjunction of two lattice points by taking the "union"
176   /// of loop indices (effectively constructing a larger "intersection" of those
177   /// indices) with a newly constructed tensor (sub)expression of given kind.
178   /// Returns the index of the new lattice point.
179   unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1,
180                         Operation *op = nullptr);
181 
182   /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of
183   /// cartesian product. Returns the index of the new set.
184   unsigned takeConj(Kind kind, unsigned s0, unsigned s1,
185                     Operation *op = nullptr);
186 
187   /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
188   /// Returns the index of the new set.
189   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1,
190                     Operation *op = nullptr);
191 
192   /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
193   /// the overlap, left, and right regions. Any region may be left missing in
194   /// the output. Returns the index of the new set.
195   unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
196                      bool includeLeft, Kind ltrans, Operation *opleft,
197                      bool includeRight, Kind rtrans, Operation *opright);
198 
199   /// Maps the unary operator over the lattice set of the operand, i.e. each
200   /// lattice point on an expression E is simply copied over, but with OP E
201   /// as new expression. Returns the index of the new set.
202   unsigned mapSet(Kind kind, unsigned s0, Value v = Value(),
203                   Operation *op = nullptr);
204 
205   /// Optimizes the iteration lattice points in the given set. This
206   /// method should be called right before code generation to avoid
207   /// generating redundant loops and conditions.
208   unsigned optimizeSet(unsigned s0);
209 
210   /// Simplifies the conditions in a conjunction of a given lattice point
211   /// within the given set using just two basic rules:
212   /// (1) multiple dense conditions are reduced to single dense, and
213   /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
214   BitVector simplifyCond(unsigned s0, unsigned p0);
215 
216   /// Returns true if Li > Lj.
217   bool latGT(unsigned i, unsigned j) const;
218 
219   /// Returns true if Li and Lj only differ in dense.
220   bool onlyDenseDiff(unsigned i, unsigned j);
221 
222   /// Bit translation.
tensor(unsigned b)223   unsigned tensor(unsigned b) const { return b % numTensors; }
index(unsigned b)224   unsigned index(unsigned b) const { return b / numTensors; }
225 
226   /// Returns true if bit corresponds to queried dim.
isDim(unsigned b,Dim d)227   bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }
228 
229   /// Returns true if bit corresponds to index of output tensor.
isOutTensor(unsigned b,unsigned i)230   bool isOutTensor(unsigned b, unsigned i) const {
231     return tensor(b) == outTensor && index(b) == i;
232   }
233 
234   /// Returns true if tensor access at given index has queried dim.
isDim(unsigned t,unsigned i,Dim d)235   bool isDim(unsigned t, unsigned i, Dim d) const {
236     assert(t < numTensors && i < numLoops);
237     return dims[t][i] == d;
238   }
239 
240   /// Returns true if any set bit corresponds to queried dim.
241   bool hasAnyDimOf(const BitVector &bits, Dim d) const;
242 
243   /// Returns true if given tensor iterates *only* in the given tensor
244   /// expression. For the output tensor, this defines a "simply dynamic"
245   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
246   /// sparse vector a.
247   bool isSingleCondition(unsigned t, unsigned e) const;
248 
249   /// Dimension setter.
setDim(unsigned t,unsigned i,Dim d)250   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
251 
252   // Has sparse output tensor setter.
setHasSparseOut(bool s)253   void setHasSparseOut(bool s) { hasSparseOut = s; }
254 
255   /// Convenience getters to immediately access the stored nodes.
256   /// Typically it is inadvisible to keep the reference around, as in
257   /// "TensorExpr &te = merger.exp(e))", since insertions into the merger
258   /// may cause data movement and invalidate the underlying memory address.
exp(unsigned e)259   TensorExp &exp(unsigned e) { return tensorExps[e]; }
lat(unsigned l)260   LatPoint &lat(unsigned l) { return latPoints[l]; }
set(unsigned s)261   SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }
262 
263 #ifndef NDEBUG
264   /// Print methods (for debugging).
265   void dumpExp(unsigned e) const;
266   void dumpLat(unsigned p) const;
267   void dumpSet(unsigned s) const;
268   void dumpBits(const BitVector &bits) const;
269 #endif
270 
271   /// Builds the iteration lattices in a bottom-up traversal given the remaining
272   /// tensor (sub)expression and the next loop index in the iteration graph.
273   /// Returns index of the root expression.
274   unsigned buildLattices(unsigned e, unsigned i);
275 
276   /// Builds a tensor expression from the given Linalg operation.
277   /// Returns index of the root expression on success.
278   Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
279 
280   /// Rebuilds SSA format from a tensor expression.
281   Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0,
282                  Value v1);
283 
284 private:
285   /// Private helpers.
286   bool maybeZero(unsigned e) const;
287   bool isInvariant(unsigned e) const;
288   Type inferType(unsigned e, Value src);
289 
290   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
291   Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
292 
293   /// Merger data structures.
294   const unsigned outTensor;
295   const unsigned syntheticTensor;
296   const unsigned numTensors;
297   const unsigned numLoops;
298   bool hasSparseOut;
299   std::vector<std::vector<Dim>> dims;
300   llvm::SmallVector<TensorExp, 32> tensorExps;
301   llvm::SmallVector<LatPoint, 16> latPoints;
302   llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
303 };
304 
305 } // namespace sparse_tensor
306 } // namespace mlir
307 
308 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
309