1 //===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines utilities for dealing with iteration lattices. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 14 #define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 15 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/IR/Value.h" 18 #include "llvm/ADT/BitVector.h" 19 20 namespace mlir { 21 namespace sparse_tensor { 22 23 /// Dimension level type for a tensor (undef means index does not appear). 24 enum Dim { kSparse, kDense, kSingle, kUndef }; 25 26 /// Tensor expression kind. 27 enum Kind { 28 // Leaf. 29 kTensor = 0, 30 kInvariant, 31 kIndex, 32 // Unary operations. 33 kAbsF, 34 kAbsC, 35 kCeilF, 36 kFloorF, 37 kSqrtF, 38 kSqrtC, 39 kExpm1F, 40 kExpm1C, 41 kLog1pF, 42 kLog1pC, 43 kSinF, 44 kSinC, 45 kTanhF, 46 kTanhC, 47 kNegF, 48 kNegC, 49 kNegI, 50 kTruncF, 51 kExtF, 52 kCastFS, // signed 53 kCastFU, // unsigned 54 kCastSF, // signed 55 kCastUF, // unsigned 56 kCastS, // signed 57 kCastU, // unsigned 58 kCastIdx, 59 kTruncI, 60 kCIm, // complex.im 61 kCRe, // complex.re 62 kBitCast, 63 kBinaryBranch, // semiring unary branch created from a binary op 64 kUnary, // semiring unary op 65 // Binary operations. 66 kMulF, 67 kMulC, 68 kMulI, 69 kDivF, 70 kDivC, // complex 71 kDivS, // signed 72 kDivU, // unsigned 73 kAddF, 74 kAddC, 75 kAddI, 76 kSubF, 77 kSubC, 78 kSubI, 79 kAndI, 80 kOrI, 81 kXorI, 82 kShrS, // signed 83 kShrU, // unsigned 84 kShlI, 85 kBinary, // semiring binary op 86 }; 87 88 /// Children subexpressions of tensor operations. 89 struct Children { 90 unsigned e0; 91 unsigned e1; 92 }; 93 94 /// Tensor expression. Represents a MLIR expression in tensor index notation. 95 struct TensorExp { 96 TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *operation); 97 98 /// Tensor expression kind. 99 Kind kind; 100 101 union { 102 /// Expressions representing tensors simply have a tensor number. 103 unsigned tensor; 104 105 /// Indices hold the index number. 106 unsigned index; 107 108 /// Tensor operations hold the indices of their children. 109 Children children; 110 }; 111 112 /// Direct link to IR for an invariant or the destination value (to 113 /// infer destination type) of a cast operation During code generation, 114 /// this field may be used to cache "hoisted" loop invariant tensor loads. 115 Value val; 116 117 /// Code blocks used by semirings. For the case of kUnary and 118 /// kBinary, this holds the original operation with all regions. For 119 /// kBinaryBranch, this holds the YieldOp for the left or right half 120 /// to be merged into a nested scf loop. 121 Operation *op; 122 }; 123 124 /// Lattice point. Each lattice point consists of a conjunction of tensor 125 /// loop indices (encoded in a bitvector) and the index of the corresponding 126 /// tensor expression. 127 struct LatPoint { 128 LatPoint(unsigned n, unsigned e, unsigned b); 129 LatPoint(const BitVector &b, unsigned e); 130 131 /// Conjunction of tensor loop indices as bitvector. This represents 132 /// all indices involved in the tensor expression 133 BitVector bits; 134 135 /// Simplified conjunction of tensor loop indices as bitvector. This 136 /// represents a simplified condition under which this tensor expression 137 /// must execute. Pre-computed during codegen to avoid repeated eval. 138 BitVector simple; 139 140 /// Index of the tensor expression. 141 unsigned exp; 142 }; 143 144 /// A class to handle all iteration lattice operations. This class abstracts 145 /// away from some implementation details of storing iteration lattices and 146 /// tensor expressions. This allows for fine-tuning performance characteristics 147 /// independently from the basic algorithm if bottlenecks are identified. 148 class Merger { 149 public: 150 /// Constructs a merger for the given number of tensors and loops. The 151 /// user supplies the number of tensors involved in the kernel, with the 152 /// last tensor in this set denoting the output tensor. The merger adds an 153 /// additional synthetic tensor at the end of this set to represent all 154 /// invariant expressions in the kernel. Merger(unsigned t,unsigned l)155 Merger(unsigned t, unsigned l) 156 : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l), 157 hasSparseOut(false), dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {} 158 159 /// Adds a tensor expression. Returns its index. 160 unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), 161 Operation *op = nullptr); 162 unsigned addExp(Kind k, unsigned e, Value v, Operation *op = nullptr) { 163 return addExp(k, e, -1u, v, op); 164 } 165 unsigned addExp(Kind k, Value v, Operation *op = nullptr) { 166 return addExp(k, -1u, -1u, v, op); 167 } 168 169 /// Adds an iteration lattice point. Returns its index. 170 unsigned addLat(unsigned t, unsigned i, unsigned e); 171 172 /// Adds a new, initially empty, set. Returns its index. 173 unsigned addSet(); 174 175 /// Computes a single conjunction of two lattice points by taking the "union" 176 /// of loop indices (effectively constructing a larger "intersection" of those 177 /// indices) with a newly constructed tensor (sub)expression of given kind. 178 /// Returns the index of the new lattice point. 179 unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1, 180 Operation *op = nullptr); 181 182 /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of 183 /// cartesian product. Returns the index of the new set. 184 unsigned takeConj(Kind kind, unsigned s0, unsigned s1, 185 Operation *op = nullptr); 186 187 /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). 188 /// Returns the index of the new set. 189 unsigned takeDisj(Kind kind, unsigned s0, unsigned s1, 190 Operation *op = nullptr); 191 192 /// Disjunctive merge of two lattice sets L0 and L1 with custom handling of 193 /// the overlap, left, and right regions. Any region may be left missing in 194 /// the output. Returns the index of the new set. 195 unsigned takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, 196 bool includeLeft, Kind ltrans, Operation *opleft, 197 bool includeRight, Kind rtrans, Operation *opright); 198 199 /// Maps the unary operator over the lattice set of the operand, i.e. each 200 /// lattice point on an expression E is simply copied over, but with OP E 201 /// as new expression. Returns the index of the new set. 202 unsigned mapSet(Kind kind, unsigned s0, Value v = Value(), 203 Operation *op = nullptr); 204 205 /// Optimizes the iteration lattice points in the given set. This 206 /// method should be called right before code generation to avoid 207 /// generating redundant loops and conditions. 208 unsigned optimizeSet(unsigned s0); 209 210 /// Simplifies the conditions in a conjunction of a given lattice point 211 /// within the given set using just two basic rules: 212 /// (1) multiple dense conditions are reduced to single dense, and 213 /// (2) a *singleton* sparse/dense is reduced to sparse/random access. 214 BitVector simplifyCond(unsigned s0, unsigned p0); 215 216 /// Returns true if Li > Lj. 217 bool latGT(unsigned i, unsigned j) const; 218 219 /// Returns true if Li and Lj only differ in dense. 220 bool onlyDenseDiff(unsigned i, unsigned j); 221 222 /// Bit translation. tensor(unsigned b)223 unsigned tensor(unsigned b) const { return b % numTensors; } index(unsigned b)224 unsigned index(unsigned b) const { return b / numTensors; } 225 226 /// Returns true if bit corresponds to queried dim. isDim(unsigned b,Dim d)227 bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } 228 229 /// Returns true if bit corresponds to index of output tensor. isOutTensor(unsigned b,unsigned i)230 bool isOutTensor(unsigned b, unsigned i) const { 231 return tensor(b) == outTensor && index(b) == i; 232 } 233 234 /// Returns true if tensor access at given index has queried dim. isDim(unsigned t,unsigned i,Dim d)235 bool isDim(unsigned t, unsigned i, Dim d) const { 236 assert(t < numTensors && i < numLoops); 237 return dims[t][i] == d; 238 } 239 240 /// Returns true if any set bit corresponds to queried dim. 241 bool hasAnyDimOf(const BitVector &bits, Dim d) const; 242 243 /// Returns true if given tensor iterates *only* in the given tensor 244 /// expression. For the output tensor, this defines a "simply dynamic" 245 /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for 246 /// sparse vector a. 247 bool isSingleCondition(unsigned t, unsigned e) const; 248 249 /// Dimension setter. setDim(unsigned t,unsigned i,Dim d)250 void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } 251 252 // Has sparse output tensor setter. setHasSparseOut(bool s)253 void setHasSparseOut(bool s) { hasSparseOut = s; } 254 255 /// Convenience getters to immediately access the stored nodes. 256 /// Typically it is inadvisible to keep the reference around, as in 257 /// "TensorExpr &te = merger.exp(e))", since insertions into the merger 258 /// may cause data movement and invalidate the underlying memory address. exp(unsigned e)259 TensorExp &exp(unsigned e) { return tensorExps[e]; } lat(unsigned l)260 LatPoint &lat(unsigned l) { return latPoints[l]; } set(unsigned s)261 SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; } 262 263 #ifndef NDEBUG 264 /// Print methods (for debugging). 265 void dumpExp(unsigned e) const; 266 void dumpLat(unsigned p) const; 267 void dumpSet(unsigned s) const; 268 void dumpBits(const BitVector &bits) const; 269 #endif 270 271 /// Builds the iteration lattices in a bottom-up traversal given the remaining 272 /// tensor (sub)expression and the next loop index in the iteration graph. 273 /// Returns index of the root expression. 274 unsigned buildLattices(unsigned e, unsigned i); 275 276 /// Builds a tensor expression from the given Linalg operation. 277 /// Returns index of the root expression on success. 278 Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op); 279 280 /// Rebuilds SSA format from a tensor expression. 281 Value buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, 282 Value v1); 283 284 private: 285 /// Private helpers. 286 bool maybeZero(unsigned e) const; 287 bool isInvariant(unsigned e) const; 288 Type inferType(unsigned e, Value src); 289 290 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 291 Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v); 292 293 /// Merger data structures. 294 const unsigned outTensor; 295 const unsigned syntheticTensor; 296 const unsigned numTensors; 297 const unsigned numLoops; 298 bool hasSparseOut; 299 std::vector<std::vector<Dim>> dims; 300 llvm::SmallVector<TensorExp, 32> tensorExps; 301 llvm::SmallVector<LatPoint, 16> latPoints; 302 llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets; 303 }; 304 305 } // namespace sparse_tensor 306 } // namespace mlir 307 308 #endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_ 309