1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering sparse tensor types to actual sparse code. 10 // 11 // The concept of letting a compiler generate sparse code automatically was 12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and 13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor 14 // Algebra Compiler (TACO). The implementation in this file closely follows 15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting 16 // rule is applied to each tensor expression in linalg (MLIR's tensor index 17 // notation) where the sparsity of tensors is indicated with annotation using 18 // a per-dimension specification of sparse/dense storage together with a 19 // specification of the order on the dimensions. Subsequently, a topologically 20 // sorted iteration graph, reflecting the required order on indices with respect 21 // to the dimensions of each tensor, is constructed to ensure that all tensors 22 // are visited in natural index order. Next, iteration lattices are constructed 23 // for the tensor expression for every index in topological order. Each 24 // iteration lattice point consists of a conjunction of tensor indices together 25 // with a tensor (sub)expression that needs to be evaluated for that 26 // conjunction. Within the lattice, iteration points are ordered according to 27 // the way indices are exhausted. As such these iteration lattices drive actual 28 // sparse code generation, which consists of a tedious but relatively 29 // straightforward one-to-one mapping from iteration lattices to combinations 30 // of for-loops, while-loops, and if-statements. 31 // 32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations. 33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php). 34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou, 35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler. 36 // Proceedings of the ACM on Programming Languages, October 2017. 37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation. 38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org). 39 // 40 // Implementation detail: We use llvm::SmallVector for vectors with 41 // variable lengths and std::vector for vectors with fixed lengths. 42 //===----------------------------------------------------------------------===// 43 44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 45 #include "mlir/Dialect/Linalg/Utils/Utils.h" 46 #include "mlir/Dialect/MemRef/IR/MemRef.h" 47 #include "mlir/Dialect/SCF/SCF.h" 48 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 49 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 50 #include "mlir/Dialect/StandardOps/IR/Ops.h" 51 #include "mlir/Dialect/Vector/VectorOps.h" 52 #include "mlir/IR/Matchers.h" 53 #include "mlir/IR/TensorEncoding.h" 54 #include "llvm/ADT/SmallBitVector.h" 55 56 using namespace mlir; 57 using namespace mlir::sparse_tensor; 58 59 namespace { 60 61 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; 62 enum class Dim { kSparse, kDense, kSingle, kUndef }; 63 64 /// Tensor expression. 
Represents an MLIR expression in tensor index notation. 65 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is 66 /// stored directly. For binary operations, e0 and e1 denote the index of the 67 /// children tensor expressions. 68 struct TensorExp { 69 TensorExp(Kind k, unsigned x, unsigned y, Value v) 70 : kind(k), e0(x), e1(y), val(v) { 71 assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || 72 (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || 73 (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); 74 } 75 Kind kind; 76 /// Indices of children expression(s). 77 unsigned e0; 78 unsigned e1; 79 /// Direct link to IR for an invariant. During code generation, 80 /// this field is used to cache "hoisted" loop invariant tensor loads. 81 Value val; 82 }; 83 84 /// Lattice point. Each lattice point consists of a conjunction of tensor 85 /// loop indices (encoded in a bitvector) and the index of the corresponding 86 /// tensor expression. 87 struct LatPoint { 88 LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) { 89 bits.set(b); 90 } 91 LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {} 92 /// Conjunction of tensor loop indices as bitvector. This represents 93 /// all indices involved in the tensor expression. 94 llvm::BitVector bits; 95 /// Simplified conjunction of tensor loop indices as bitvector. This 96 /// represents a simplified condition under which this tensor expression 97 /// must execute. Pre-computed during codegen to avoid repeated eval. 98 llvm::BitVector simple; 99 /// Index of the tensor expression. 100 unsigned exp; 101 }; 102 103 /// A class to handle all iteration lattice operations. This class abstracts 104 /// away from some implementation details of storing iteration lattices and 105 /// tensor expressions. This allows for fine-tuning performance characteristics 106 /// independently from the basic algorithm if bottlenecks are identified. 107 class Merger { 108 public: 109 /// Constructs a merger for the given number of tensors and loops. The 110 /// user supplies the number of tensors involved in the kernel, with the 111 /// last tensor in this set denoting the output tensor. The merger adds an 112 /// additional synthetic tensor at the end of this set to represent all 113 /// invariant expressions in the kernel. 114 Merger(unsigned t, unsigned l) 115 : outTensor(t - 1), numTensors(t + 1), numLoops(l), 116 dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {} 117 118 /// Adds a tensor expression. Returns its index. 119 unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) { 120 unsigned e = tensorExps.size(); 121 tensorExps.push_back(TensorExp(k, e0, e1, v)); 122 return e; 123 } 124 unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } 125 126 /// Adds an iteration lattice point. Returns its index. 127 unsigned addLat(unsigned t, unsigned i, unsigned e) { 128 assert(t < numTensors && i < numLoops); 129 unsigned p = latPoints.size(); 130 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); 131 return p; 132 } 133 134 /// Adds a new, initially empty, set. Returns its index.
135 unsigned addSet() { 136 unsigned s = latSets.size(); 137 latSets.emplace_back(SmallVector<unsigned, 16>()); 138 return s; 139 } 140 141 /// Computes a single conjunction of two lattice points by taking the "union" 142 /// of loop indices (effectively constructing a larger "intersection" of those 143 /// indices) with a newly constructed tensor (sub)expression of given kind. 144 /// Returns the index of the new lattice point. 145 unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 146 unsigned p = latPoints.size(); 147 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 148 nb |= latPoints[p1].bits; 149 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 150 latPoints.push_back(LatPoint(nb, e)); 151 return p; 152 } 153 154 /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of 155 /// cartesian product. Returns the index of the new set. 156 unsigned takeConj(Kind kind, unsigned s0, unsigned s1) { 157 unsigned s = addSet(); 158 for (unsigned p0 : latSets[s0]) 159 for (unsigned p1 : latSets[s1]) 160 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 161 return s; 162 } 163 164 /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). 165 /// Returns the index of the new set. 166 unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) { 167 unsigned s = takeConj(kind, s0, s1); 168 for (unsigned p : latSets[s0]) 169 latSets[s].push_back(p); 170 for (unsigned p : latSets[s1]) 171 latSets[s].push_back(p); 172 return s; 173 } 174 175 /// Optimizes the iteration lattice points in the given set. This 176 /// method should be called right before code generation to avoid 177 /// generating redundant loops and conditions. 178 unsigned optimizeSet(unsigned s0) { 179 unsigned s = addSet(); 180 assert(latSets[s0].size() != 0); 181 unsigned p0 = latSets[s0][0]; 182 for (unsigned p1 : latSets[s0]) { 183 bool add = true; 184 if (p0 != p1) { 185 // Is this a straightforward copy? 186 unsigned e = latPoints[p1].exp; 187 if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) 188 continue; 189 // Conjunction already covered? 190 for (unsigned p2 : latSets[s]) { 191 assert(!latGT(p1, p2)); // Lj => Li would be bad 192 if (onlyDenseDiff(p2, p1)) { 193 add = false; 194 break; 195 } 196 } 197 assert(!add || latGT(p0, p1)); 198 } 199 if (add) 200 latSets[s].push_back(p1); 201 } 202 for (unsigned p : latSets[s]) 203 latPoints[p].simple = simplifyCond(s, p); 204 return s; 205 } 206 207 /// Simplifies the conditions in a conjunction of a given lattice point 208 /// within the given set using just two basic rules: 209 /// (1) multiple dense conditions are reduced to single dense, and 210 /// (2) a *singleton* sparse/dense is reduced to sparse/random access. 211 llvm::BitVector simplifyCond(unsigned s, unsigned p0) { 212 // First determine if this lattice point is a *singleton*, i.e., 213 // the last point in a lattice, no other is less than this one. 214 bool isSingleton = true; 215 for (unsigned p1 : latSets[s]) { 216 if (p0 != p1 && latGT(p0, p1)) { 217 isSingleton = false; 218 break; 219 } 220 } 221 // Now apply the two basic rules. 222 llvm::BitVector simple = latPoints[p0].bits; 223 bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); 224 for (unsigned b = 0, be = simple.size(); b < be; b++) { 225 if (simple[b] && !isDim(b, Dim::kSparse)) { 226 if (reset) 227 simple.reset(b); 228 reset = true; 229 } 230 } 231 return simple; 232 } 233 234 /// Returns true if Li > Lj. 
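/// That is, the loop indices of Lj form a strict subset of the loop indices of Li.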
235 bool latGT(unsigned i, unsigned j) const { 236 const llvm::BitVector &bitsi = latPoints[i].bits; 237 const llvm::BitVector &bitsj = latPoints[j].bits; 238 assert(bitsi.size() == bitsj.size()); 239 if (bitsi.count() > bitsj.count()) { 240 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 241 if (bitsj[b] && !bitsi[b]) 242 return false; 243 return true; 244 } 245 return false; 246 } 247 248 /// Returns true if Li and Lj only differ in dense. 249 bool onlyDenseDiff(unsigned i, unsigned j) { 250 llvm::BitVector tmp = latPoints[j].bits; 251 tmp ^= latPoints[i].bits; 252 return !hasAnyDimOf(tmp, Dim::kSparse); 253 } 254 255 /// Bit translation. 256 unsigned tensor(unsigned b) const { return b % numTensors; } 257 unsigned index(unsigned b) const { return b / numTensors; } 258 259 /// Returns true if bit corresponds to queried dim. 260 bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } 261 262 /// Returns true if bit corresponds to index of output tensor. 263 bool isOutTensor(unsigned b, unsigned i) const { 264 return tensor(b) == outTensor && index(b) == i; 265 } 266 267 /// Returns true if tensor access at given index has queried dim. 268 bool isDim(unsigned t, unsigned i, Dim d) const { 269 assert(t < numTensors && i < numLoops); 270 return dims[t][i] == d; 271 } 272 273 /// Returns true if any set bit corresponds to queried dim. 274 bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 275 for (unsigned b = 0, be = bits.size(); b < be; b++) 276 if (bits[b] && isDim(b, d)) 277 return true; 278 return false; 279 } 280 281 /// Setter 282 void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } 283 284 /// Getters. 285 TensorExp &exp(unsigned e) { return tensorExps[e]; } 286 LatPoint &lat(unsigned l) { return latPoints[l]; } 287 SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; } 288 289 private: 290 const unsigned outTensor; 291 const unsigned numTensors; 292 const unsigned numLoops; 293 294 std::vector<std::vector<Dim>> dims; 295 llvm::SmallVector<TensorExp, 32> tensorExps; 296 llvm::SmallVector<LatPoint, 16> latPoints; 297 llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets; 298 }; 299 300 // Code generation. 301 struct CodeGen { 302 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) 303 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 304 pointers(numTensors, std::vector<Value>(numLoops)), 305 indices(numTensors, std::vector<Value>(numLoops)), 306 highs(numTensors, std::vector<Value>(numLoops)), 307 pidxs(numTensors, std::vector<Value>(numLoops)), 308 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 309 curVecLength(1), curVecMask() {} 310 /// Sparsification options. 311 SparsificationOptions options; 312 /// Universal dense indices and upper bounds (by index). The loops array 313 /// is updated with the value of the universal dense index in the current 314 /// loop. The sizes array is set once with the inferred dimension sizes. 315 std::vector<Value> loops; 316 std::vector<Value> sizes; 317 /// Buffers for storing dense and sparse numerical values (by tensor). 318 /// This array is set once during bufferization of all tensors. 319 std::vector<Value> buffers; 320 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 321 /// This array is set once during bufferization of all sparse tensors. 322 std::vector<std::vector<Value>> pointers; 323 std::vector<std::vector<Value>> indices; 324 /// Sparse iteration information (by tensor and index). 
These arrays 325 /// are updated to remain current within the current loop. 326 std::vector<std::vector<Value>> highs; 327 std::vector<std::vector<Value>> pidxs; 328 std::vector<std::vector<Value>> idxs; 329 /// Current reduction, updated during code generation. When indices of a 330 /// reduction are exhausted, all inner loops can "scalarize" the reduction. 331 // TODO: currently only done for (a chain of) innermost for-loops, where it 332 // is most effective; we could generalize to more outer and while-loops. 333 unsigned redExp; 334 Value redVal; 335 // Current vector length and mask. 336 unsigned curVecLength; 337 Value curVecMask; 338 }; 339 340 } // namespace 341 342 // Helper method to apply dimension ordering permutation. 343 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) { 344 if (enc) { 345 auto order = enc.getDimOrdering(); 346 if (order) { 347 assert(order.isPermutation()); 348 return order.getDimPosition(d); 349 } 350 } 351 return d; 352 } 353 354 // Helper method to translate dim level type to internal representation. 355 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) { 356 if (enc) { 357 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 358 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 359 return Dim::kSparse; 360 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 361 return Dim::kSingle; 362 } 363 return Dim::kDense; 364 } 365 366 /// Helper method to inspect sparse encodings in the tensor types. 367 /// Fills the per-dimension sparsity information for all tensors. 368 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 369 bool annotated = false; 370 for (OpOperand *t : op.getInputAndOutputOperands()) { 371 auto map = op.getTiedIndexingMap(t); 372 if (!map.isProjectedPermutation()) 373 return false; 374 auto enc = getSparseTensorEncoding(t->get().getType()); 375 if (enc) 376 annotated = true; 377 assert(map.getNumResults() == op.getRank(t)); 378 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 379 unsigned idx = map.getDimPosition(perm(enc, d)); 380 merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); 381 } 382 } 383 return annotated; 384 } 385 386 /// A DFS helper to compute a topological sort. Note that recursion is 387 /// bounded by the number of implicit loops, which is always small. 388 /// Returns false when a cycle is detected. 389 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 390 std::vector<unsigned> &topSort, 391 std::vector<std::vector<bool>> &adjM) { 392 if (visit[i] != 0) 393 return visit[i] != 1; // 1 denotes cycle! 394 visit[i] = 1; 395 for (unsigned j = 0, e = visit.size(); j < e; j++) 396 if (adjM[i][j]) 397 if (!topSortDFS(j, visit, topSort, adjM)) 398 return false; 399 visit[i] = 2; 400 topSort.push_back(i); 401 return true; 402 } 403 404 /// Computes a topologically sorted iteration graph for the linalg operation. 405 /// Ensures all tensors are visited in natural index order. This is essential 406 /// for sparse storage formats since these only support access along fixed 407 /// dimensions. Even for dense storage formats, however, the natural index 408 /// order yields innermost unit-stride access with better spatial locality. 409 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 410 std::vector<unsigned> &topSort, 411 bool sparseOnly) { 412 // Set up an n x n from/to adjacency matrix of the iteration graph 413 // for the implicit loop indices i_0 .. i_n-1. 
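// An edge i -> j in this matrix means that loop index i must be assigned a loop that is nested outside the loop for index j.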
414 unsigned n = op.getNumLoops(); 415 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 416 417 // Iterate over the indexing maps of every tensor in the tensor expression. 418 for (OpOperand *t : op.getInputAndOutputOperands()) { 419 auto map = op.getTiedIndexingMap(t); 420 auto enc = getSparseTensorEncoding(t->get().getType()); 421 assert(map.getNumDims() == n); 422 // Skip dense tensor constraints when sparse only is requested. 423 if (sparseOnly && !enc) 424 continue; 425 // Each tensor expression and optional dimension ordering (row-major 426 // by default) puts an ordering constraint on the loop indices. For 427 // example, the tensor expression A_ijk forces the ordering i < j < k 428 // on the loop indices if no explicit dimension ordering is given. 429 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 430 unsigned f = map.getDimPosition(perm(enc, d - 1)); 431 unsigned t = map.getDimPosition(perm(enc, d)); 432 adjM[f][t] = true; 433 } 434 } 435 436 // Topologically sort the iteration graph to determine loop order. 437 // Report failure for a cyclic iteration graph. 438 topSort.clear(); 439 topSort.reserve(n); 440 std::vector<unsigned> visit(n, 0); 441 for (unsigned i = 0; i < n; i++) 442 if (visit[i] == 0) 443 if (!topSortDFS(i, visit, topSort, adjM)) 444 return false; // cycle! 445 std::reverse(std::begin(topSort), std::end(topSort)); 446 return true; 447 } 448 449 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 450 /// This simplifies constructing (sub)expressions during iteration lattice 451 /// building (compared to using the SSA representation everywhere). 452 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op, 453 Value val) { 454 if (auto arg = val.dyn_cast<BlockArgument>()) { 455 unsigned argN = arg.getArgNumber(); 456 // Any argument of the generic op that is not marked as a scalar 457 // argument is considered a tensor, indexed by the implicit loop 458 // bounds. This includes rank-0 tensor arguments. 459 if (arg.getOwner()->getParentOp() == op) { 460 OpOperand *t = op.getInputAndOutputOperands()[argN]; 461 if (!op.isScalar(t)) 462 return merger.addExp(Kind::kTensor, argN); 463 val = t->get(); // get scalar value 464 } 465 // Any other argument (marked as scalar argument for the generic op 466 // or belonging to an enveloping op) is considered invariant. 467 return merger.addExp(Kind::kInvariant, val); 468 } 469 Operation *def = val.getDefiningOp(); 470 if (def->getBlock() != &op.region().front()) { 471 // Something defined outside is invariant. 472 return merger.addExp(Kind::kInvariant, val); 473 } else if (def->getNumOperands() == 2) { 474 // Construct binary operations if subexpressions could be built. 475 auto x = buildTensorExp(merger, op, def->getOperand(0)); 476 auto y = buildTensorExp(merger, op, def->getOperand(1)); 477 if (x.hasValue() && y.hasValue()) { 478 unsigned e0 = x.getValue(); 479 unsigned e1 = y.getValue(); 480 if (isa<MulFOp>(def)) 481 return merger.addExp(Kind::kMulF, e0, e1); 482 if (isa<MulIOp>(def)) 483 return merger.addExp(Kind::kMulI, e0, e1); 484 if (isa<AddFOp>(def)) 485 return merger.addExp(Kind::kAddF, e0, e1); 486 if (isa<AddIOp>(def)) 487 return merger.addExp(Kind::kAddI, e0, e1); 488 } 489 } 490 // Cannot build (yet). 491 return None; 492 } 493 494 /// Returns true if given tensor co-iterates with conjunction only. 495 /// For the output tensor, this defines a "simply dynamic" operation.
For instance: A(I) = A(I) * B(I) * C(I) 497 static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) { 498 switch (merger.exp(exp).kind) { 499 case Kind::kTensor: 500 return merger.exp(exp).e0 == tensor; 501 case Kind::kMulF: 502 case Kind::kMulI: 503 return isConjunction(merger, tensor, merger.exp(exp).e0) || 504 isConjunction(merger, tensor, merger.exp(exp).e1); 505 default: 506 return false; 507 } 508 } 509 510 /// Returns true when the tensor expression is admissable for codegen. 511 /// Since all sparse input tensors are admissable, we just need to check 512 /// whether the output tensor in the tensor expression is admissable for codegen. 513 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op, 514 unsigned exp) { 515 OpOperand *lhs = op.getOutputOperand(0); 516 unsigned tensor = lhs->getOperandNumber(); 517 auto enc = getSparseTensorEncoding(lhs->get().getType()); 518 // A non-annotated output tensor is assumed dense, and becomes a random 519 // access n-dim memref. Admissable since insertions cannot occur. 520 if (!enc) 521 return true; 522 // An all-dense annotated "sparse" output tensor becomes a linearized random 523 // access 1-dim memref. Also admissable since insertions cannot occur. 524 bool allDense = true; 525 unsigned numLoops = op.iterator_types().getValue().size(); 526 for (unsigned i = 0; i < numLoops; i++) 527 if (merger.isDim(tensor, i, Dim::kSparse)) { 528 allDense = false; 529 break; 530 } 531 if (allDense) 532 return true; 533 // A tensor expression with a sparse output tensor that changes its values 534 // but not its nonzero structure, an operation called "simply dynamic" in 535 // [Bik96,Ch9], is also admissable without special codegen. 536 if (isConjunction(merger, tensor, exp)) 537 return true; 538 // Reject for now since this requires changes to the nonzero structure. 539 // TODO: implement "workspaces" [Kjolstad2019] 540 return false; 541 } 542 543 /// Builds the iteration lattices in a bottom-up traversal given the remaining 544 /// tensor (sub)expression and the next loop index in the iteration graph. 545 static unsigned buildLattices(Merger &merger, linalg::GenericOp op, 546 unsigned exp, unsigned idx) { 547 Kind kind = merger.exp(exp).kind; 548 if (kind == Kind::kTensor || kind == Kind::kInvariant) { 549 // Either the index is really used in the tensor expression, or it is 550 // set to the undefined index in that dimension. An invariant expression 551 // is set to a synthetic tensor with undefined indices only. 552 unsigned s = merger.addSet(); 553 unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 554 : op.getNumInputsAndOutputs(); 555 merger.set(s).push_back(merger.addLat(t, idx, exp)); 556 return s; 557 } 558 unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); 559 unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx); 560 switch (kind) { 561 case Kind::kTensor: 562 case Kind::kInvariant: 563 llvm_unreachable("handled above"); 564 case Kind::kMulF: 565 case Kind::kMulI: 566 return merger.takeConj(kind, s0, s1); 567 case Kind::kAddF: 568 case Kind::kAddI: 569 return merger.takeDisj(kind, s0, s1); 570 } 571 llvm_unreachable("unexpected expression kind"); 572 } 573 574 /// Maps sparse integer option to actual integral storage type. 575 static Type genIntType(PatternRewriter &rewriter, unsigned width) { 576 if (width == 0) 577 return rewriter.getIndexType(); 578 return rewriter.getIntegerType(width); 579 } 580 581 /// Detects in-place annotation on tensor argument.
582 static bool getInPlace(Value val) { 583 if (auto arg = val.dyn_cast<BlockArgument>()) 584 if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) 585 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 586 arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) 587 return attr.getValue(); 588 return false; 589 } 590 591 /// Generates buffer for the output tensor. 592 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, 593 linalg::GenericOp op, MemRefType denseTp, 594 ArrayRef<Value> args) { 595 Location loc = op.getLoc(); 596 Value tensor = op.getOutputOperand(0)->get(); 597 // The output tensor could simply materialize from the buffer that will 598 // be generated for the tensor present in the outs() clause. This has 599 // the major advantage that the sparse kernel only updates the nonzero 600 // positions for the output tensor. 601 if (getInPlace(tensor)) 602 return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 603 // By default, a new buffer is allocated which is initialized to the 604 // tensor defined in the outs() clause. This is always correct but 605 // introduces a dense initialization component that may negatively 606 // impact the running complexity of the sparse kernel. 607 Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 608 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 609 rewriter.create<linalg::CopyOp>(loc, init, alloc); 610 return alloc; 611 } 612 613 /// Local bufferization of all dense and sparse data structures. 614 /// This code enables testing the first prototype sparse compiler. 615 // TODO: replace this with a proliferated bufferization strategy 616 static bool genBuffers(Merger &merger, CodeGen &codegen, 617 PatternRewriter &rewriter, linalg::GenericOp op) { 618 Location loc = op.getLoc(); 619 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 620 // For every tensor, find lower and upper bound on dimensions, set the 621 // same bounds on loop indices, and obtain dense or sparse buffer(s). 622 SmallVector<Value, 4> args; 623 for (OpOperand *t : op.getInputAndOutputOperands()) { 624 unsigned tensor = t->getOperandNumber(); 625 auto shape = op.getShape(t); 626 auto map = op.getTiedIndexingMap(t); 627 auto enc = getSparseTensorEncoding(t->get().getType()); 628 // Scan all dimensions of current tensor. 629 args.clear(); 630 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 631 unsigned idx = map.getDimPosition(perm(enc, d)); 632 // Handle sparse storage schemes. 633 if (merger.isDim(tensor, idx, Dim::kSparse)) { 634 auto dynShape = {ShapedType::kDynamicSize}; 635 auto ptrTp = MemRefType::get( 636 dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 637 auto indTp = MemRefType::get( 638 dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 639 Value dim = rewriter.create<ConstantIndexOp>(loc, d); 640 // Generate sparse primitives to obtain pointers and indices. 641 codegen.pointers[tensor][idx] = 642 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 643 codegen.indices[tensor][idx] = 644 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 645 } 646 // Find lower and upper bound in current dimension.
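// Dynamic dimensions are queried with memref.dim and also collected in `args`, which genOutputBuffer may use to allocate the output buffer.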
647 Value up; 648 if (shape[d] == MemRefType::kDynamicSize) { 649 up = rewriter.create<memref::DimOp>(loc, t->get(), d); 650 args.push_back(up); 651 } else { 652 up = rewriter.create<ConstantIndexOp>(loc, shape[d]); 653 } 654 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 655 } 656 // Perform the required bufferization. Dense inputs materialize 657 // from the input tensors. Dense outputs need special handling. 658 // Sparse inputs use sparse primitives to obtain the values. 659 // We also accept in-place all-dense annotated "sparse" outputs. 660 Type elementType = getElementTypeOrSelf(t->get().getType()); 661 if (!enc) { 662 // Non-annotated dense tensors. 663 auto denseTp = MemRefType::get(shape, elementType); 664 if (tensor < op.getNumInputs()) 665 codegen.buffers[tensor] = 666 rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get()); 667 else 668 codegen.buffers[tensor] = 669 genOutputBuffer(codegen, rewriter, op, denseTp, args); 670 } else { 671 // Annotated sparse tensors. 672 if (tensor == op.getNumInputs() && !getInPlace(t->get())) 673 return false; // reject output if not in-place 674 auto dynShape = {ShapedType::kDynamicSize}; 675 auto sparseTp = MemRefType::get(dynShape, elementType); 676 codegen.buffers[tensor] = 677 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 678 } 679 } 680 return true; 681 } 682 683 /// Constructs vector type. 684 static VectorType vectorType(CodeGen &codegen, Type etp) { 685 return VectorType::get(codegen.curVecLength, etp); 686 } 687 688 /// Constructs vector type from pointer. 689 static VectorType vectorType(CodeGen &codegen, Value ptr) { 690 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 691 } 692 693 /// Constructs vector iteration mask. 694 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 695 Value iv, Value lo, Value hi, Value step) { 696 Location loc = iv.getLoc(); 697 VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1)); 698 // Special case if the vector length evenly divides the trip count (for 699 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 700 // so that all subsequent masked memory operations are immediately folded 701 // into unconditional memory operations. 702 IntegerAttr loInt, hiInt, stepInt; 703 if (matchPattern(lo, m_Constant(&loInt)) && 704 matchPattern(hi, m_Constant(&hiInt)) && 705 matchPattern(step, m_Constant(&stepInt))) { 706 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 707 return rewriter.create<vector::BroadcastOp>( 708 loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1)); 709 } 710 // Otherwise, generate a vector mask that avoids overrunning the upperbound 711 // during vector execution. Here we rely on subsequent loop optimizations to 712 // avoid executing the mask in all iterations, for example, by splitting the 713 // loop into an unconditional vector loop and a scalar cleanup loop. 714 Value end = rewriter.create<SubIOp>(loc, hi, iv); 715 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 716 } 717 718 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 
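/// A vector-valued last argument selects a gather; otherwise a masked (contiguous) load is generated.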
719 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 720 Value ptr, ArrayRef<Value> args) { 721 Location loc = ptr.getLoc(); 722 VectorType vtp = vectorType(codegen, ptr); 723 Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 724 if (args.back().getType().isa<VectorType>()) { 725 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 726 Value indexVec = args.back(); 727 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 728 return rewriter.create<vector::GatherOp>( 729 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 730 } 731 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 732 codegen.curVecMask, pass); 733 } 734 735 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 736 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 737 Value rhs, Value ptr, ArrayRef<Value> args) { 738 Location loc = ptr.getLoc(); 739 if (args.back().getType().isa<VectorType>()) { 740 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 741 Value indexVec = args.back(); 742 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 743 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 744 codegen.curVecMask, rhs); 745 return; 746 } 747 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 748 rhs); 749 } 750 751 /// Generates a vectorized invariant. Here we rely on subsequent loop 752 /// optimizations to hoist the invariant broadcast out of the vector loop. 753 static Value genVectorInvariantValue(CodeGen &codegen, 754 PatternRewriter &rewriter, Value val) { 755 VectorType vtp = vectorType(codegen, val.getType()); 756 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 757 } 758 759 /// Generates a load on a dense or sparse tensor. 760 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 761 PatternRewriter &rewriter, linalg::GenericOp op, 762 unsigned exp) { 763 // Test if the load was hoisted to a higher loop nest. 764 Value val = merger.exp(exp).val; 765 if (val) { 766 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 767 return genVectorInvariantValue(codegen, rewriter, val); 768 return val; 769 } 770 // Actual load. 771 SmallVector<Value, 4> args; 772 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 773 unsigned tensor = t->getOperandNumber(); 774 auto map = op.getTiedIndexingMap(t); 775 auto enc = getSparseTensorEncoding(t->get().getType()); 776 unsigned rank = map.getNumResults(); 777 if (enc) { 778 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 779 assert(codegen.pidxs[tensor][idx] != nullptr); 780 args.push_back(codegen.pidxs[tensor][idx]); // position index 781 } else { 782 for (unsigned d = 0; d < rank; d++) { 783 unsigned idx = map.getDimPosition(d); 784 args.push_back(codegen.loops[idx]); // universal dense index 785 } 786 } 787 Location loc = op.getLoc(); 788 Value ptr = codegen.buffers[tensor]; 789 if (codegen.curVecLength > 1) 790 return genVectorLoad(codegen, rewriter, ptr, args); 791 return rewriter.create<memref::LoadOp>(loc, ptr, args); 792 } 793 794 /// Generates a store on a dense or sparse tensor. 795 static void genTensorStore(Merger &merger, CodeGen &codegen, 796 PatternRewriter &rewriter, linalg::GenericOp op, 797 OpOperand *t, Value rhs) { 798 Location loc = op.getLoc(); 799 // Test if this is a scalarized reduction. 
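// If so, the updated value is kept in codegen.redVal instead of being written to memory; genReductionEnd emits the final store.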
800 OpOperand *lhs = op.getOutputOperand(0); 801 if (lhs == t && codegen.redVal) { 802 if (codegen.curVecLength > 1) 803 rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs, 804 codegen.redVal); 805 codegen.redVal = rhs; 806 return; 807 } 808 // Actual store. 809 SmallVector<Value, 4> args; 810 unsigned tensor = t->getOperandNumber(); 811 auto map = op.getTiedIndexingMap(t); 812 auto enc = getSparseTensorEncoding(t->get().getType()); 813 unsigned rank = map.getNumResults(); 814 if (enc) { 815 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 816 assert(codegen.pidxs[tensor][idx] != nullptr); 817 args.push_back(codegen.pidxs[tensor][idx]); // position index 818 } else { 819 for (unsigned d = 0; d < rank; d++) { 820 unsigned idx = map.getDimPosition(d); 821 args.push_back(codegen.loops[idx]); // universal dense index 822 } 823 } 824 Value ptr = codegen.buffers[tensor]; 825 if (codegen.curVecLength > 1) 826 genVectorStore(codegen, rewriter, rhs, ptr, args); 827 else 828 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 829 } 830 831 /// Generates a pointer/index load from the sparse storage scheme. Narrower 832 /// data types need to be zero extended before casting the value into the 833 /// index type used for looping and indexing. 834 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, 835 Value ptr, Value s) { 836 // See https://llvm.org/docs/GetElementPtr.html for some background on 837 // the complications described below. 838 if (codegen.curVecLength > 1) { 839 // Since the index vector is used in subsequent gather/scatter operations, 840 // which effectively define an unsigned pointer + signed index, we must 841 // zero extend the vector to an index width. For 8-bit and 16-bit values, 842 // a 32-bit index width suffices. For 32-bit values, zero extending the 843 // elements into 64-bit loses some performance since the 32-bit indexed 844 // gather/scatter is more efficient than the 64-bit index variant (if the 845 // negative 32-bit index space is unused, the enableSIMDIndex32 flag can 846 // preserve this performance). For 64-bit values, there is no good way 847 // to state that the indices are unsigned, which creates the potential for 848 // incorrect address calculations in the unlikely case we need such 849 // extremely large offsets. 850 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 851 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 852 if (!etp.isa<IndexType>()) { 853 if (etp.getIntOrFloatBitWidth() < 32) 854 vload = rewriter.create<ZeroExtendIOp>( 855 loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); 856 else if (etp.getIntOrFloatBitWidth() < 64 && 857 !codegen.options.enableSIMDIndex32) 858 vload = rewriter.create<ZeroExtendIOp>( 859 loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); 860 } 861 return vload; 862 } 863 // For the scalar case, we simply zero extend narrower indices into 64-bit 864 // values before casting to index without a performance penalty. Here too, 865 // however, indices that already are 64-bit, in theory, cannot express the 866 // full range as explained above. 867 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 868 if (!load.getType().isa<IndexType>()) { 869 if (load.getType().getIntOrFloatBitWidth() < 64) 870 load = rewriter.create<ZeroExtendIOp>(loc, load, 871 rewriter.getIntegerType(64)); 872 load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType()); 873 } 874 return load; 875 } 876 877 /// Generates an invariant value.
878 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 879 PatternRewriter &rewriter, unsigned exp) { 880 Value val = merger.exp(exp).val; 881 if (codegen.curVecLength > 1) 882 return genVectorInvariantValue(codegen, rewriter, val); 883 return val; 884 } 885 886 /// Generates an address computation "sz * p + i". 887 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 888 Location loc, Value size, Value p, Value i) { 889 Value mul = rewriter.create<MulIOp>(loc, size, p); 890 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 891 Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType()); 892 mul = genVectorInvariantValue(codegen, rewriter, inv); 893 } 894 return rewriter.create<AddIOp>(loc, mul, i); 895 } 896 897 /// Generates start of a reduction. 898 static Value genReductionStart(Merger &merger, CodeGen &codegen, 899 PatternRewriter &rewriter, 900 linalg::GenericOp op) { 901 if (codegen.redVal) 902 return codegen.redVal; // chained with previous for-loop 903 if (codegen.curVecLength > 1) { 904 // TODO: assumes + reductions for now 905 VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]); 906 return rewriter.create<ConstantOp>(op.getLoc(), vtp, 907 rewriter.getZeroAttr(vtp)); 908 } 909 return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 910 } 911 912 /// Generates end of a reduction. 913 static void genReductionEnd(Merger &merger, CodeGen &codegen, 914 PatternRewriter &rewriter, linalg::GenericOp op) { 915 Value red = codegen.redVal; 916 if (!red) 917 return; 918 assert(codegen.curVecLength == 1); 919 codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain 920 OpOperand *lhs = op.getOutputOperand(0); 921 if (auto vtp = red.getType().dyn_cast<VectorType>()) { 922 // TODO: assumes + reductions for now 923 StringAttr kind = rewriter.getStringAttr("add"); 924 Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 925 // Integer reductions don't accept an accumulator. 926 if (vtp.getElementType().isa<IntegerType>()) { 927 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 928 kind, red, ValueRange{}); 929 red = rewriter.create<AddIOp>(op.getLoc(), red, ld); 930 } else { 931 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 932 kind, red, ld); 933 } 934 } 935 genTensorStore(merger, codegen, rewriter, op, lhs, red); 936 } 937 938 /// Recursively generates tensor expression. 
939 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 940 linalg::GenericOp op, unsigned exp) { 941 if (merger.exp(exp).kind == Kind::kTensor) 942 return genTensorLoad(merger, codegen, rewriter, op, exp); 943 else if (merger.exp(exp).kind == Kind::kInvariant) 944 return genInvariantValue(merger, codegen, rewriter, exp); 945 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); 946 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); 947 switch (merger.exp(exp).kind) { 948 case Kind::kTensor: 949 case Kind::kInvariant: 950 llvm_unreachable("handled above"); 951 case Kind::kMulF: 952 return rewriter.create<MulFOp>(op.getLoc(), v0, v1); 953 case Kind::kMulI: 954 return rewriter.create<MulIOp>(op.getLoc(), v0, v1); 955 case Kind::kAddF: 956 return rewriter.create<AddFOp>(op.getLoc(), v0, v1); 957 case Kind::kAddI: 958 return rewriter.create<AddIOp>(op.getLoc(), v0, v1); 959 } 960 llvm_unreachable("unexpected expression kind"); 961 } 962 963 /// Hoists loop invariant tensor loads for which indices have been exhausted. 964 static void genInvariants(Merger &merger, CodeGen &codegen, 965 PatternRewriter &rewriter, linalg::GenericOp op, 966 unsigned exp, unsigned ldx, bool hoist) { 967 if (merger.exp(exp).kind == Kind::kTensor) { 968 // Inspect tensor indices. 969 bool atLevel = ldx == -1u; 970 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 971 auto map = op.getTiedIndexingMap(t); 972 auto enc = getSparseTensorEncoding(t->get().getType()); 973 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 974 unsigned idx = map.getDimPosition(perm(enc, d)); 975 if (!codegen.loops[idx]) 976 return; // still in play 977 else if (idx == ldx) 978 atLevel = true; 979 } 980 // All exhausted at this level (atLevel denotes exactly at this level). 981 OpOperand *lhs = op.getOutputOperand(0); 982 if (lhs == t) { 983 codegen.redExp = hoist ? exp : -1u; 984 } else if (atLevel) { 985 merger.exp(exp).val = 986 hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 987 } 988 } else if (merger.exp(exp).kind != Kind::kInvariant) { 989 // Traverse into the binary operations. Note that we only hoist 990 // tensor loads, since subsequent MLIR/LLVM passes know how to 991 // deal with all other kinds of derived loop invariants. 992 unsigned e0 = merger.exp(exp).e0; 993 unsigned e1 = merger.exp(exp).e1; 994 genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); 995 genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); 996 } 997 } 998 999 /// Generates initialization code for the subsequent loop sequence at 1000 /// current index level. Returns true if the loop sequence needs to 1001 /// maintain the universal index. 1002 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1003 linalg::GenericOp op, std::vector<unsigned> &topSort, 1004 unsigned at, llvm::BitVector &inits) { 1005 bool needsUniv = false; 1006 Location loc = op.getLoc(); 1007 unsigned idx = topSort[at]; 1008 1009 // Initialize sparse positions. 1010 for (unsigned b = 0, be = inits.size(); b < be; b++) { 1011 if (inits[b]) { 1012 unsigned tensor = merger.tensor(b); 1013 assert(idx == merger.index(b)); 1014 if (merger.isDim(b, Dim::kSparse)) { 1015 // Initialize sparse index. 
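// The position range [lo, hi) for this loop is loaded from pointers[p0] and pointers[p0 + 1], where p0 is the position index of the innermost enclosing loop of this tensor (or constant zero at the outermost level).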
1016 unsigned pat = at; 1017 for (; pat != 0; pat--) { 1018 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1019 break; 1020 } 1021 Value ptr = codegen.pointers[tensor][idx]; 1022 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 1023 Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 1024 : codegen.pidxs[tensor][topSort[pat - 1]]; 1025 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 1026 Value p1 = rewriter.create<AddIOp>(loc, p0, one); 1027 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 1028 } else { 1029 // Dense index still in play. 1030 needsUniv = true; 1031 } 1032 } 1033 } 1034 1035 // Initialize the universal dense index. 1036 codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0); 1037 return needsUniv; 1038 } 1039 1040 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 1041 /// operation is a candidate. Whether it is actually converted to SIMD code 1042 /// depends on the requested strategy. 1043 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 1044 switch (codegen.options.vectorizationStrategy) { 1045 case SparseVectorizationStrategy::kNone: 1046 return false; 1047 case SparseVectorizationStrategy::kDenseInnerLoop: 1048 return isInner && !isSparse; 1049 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1050 return isInner; 1051 } 1052 llvm_unreachable("unexpected vectorization strategy"); 1053 } 1054 1055 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1056 /// that is marked "parallel" is a candidate. Whether it is actually converted 1057 /// to a parallel operation depends on the requested strategy. 1058 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1059 bool isSparse, bool isVector) { 1060 switch (codegen.options.parallelizationStrategy) { 1061 case SparseParallelizationStrategy::kNone: 1062 return false; 1063 case SparseParallelizationStrategy::kDenseOuterLoop: 1064 return isOuter && !isSparse && !isReduction && !isVector; 1065 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1066 return isOuter && !isReduction && !isVector; 1067 case SparseParallelizationStrategy::kDenseAnyLoop: 1068 return !isSparse && !isReduction && !isVector; 1069 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1070 return !isReduction && !isVector; 1071 } 1072 llvm_unreachable("unexpected parallelization strategy"); 1073 } 1074 1075 /// Checks unit strides for dense tensors. The iteration graph may have ignored 1076 /// dense access patterns in order to avoid cycles (sparse access patterns are 1077 /// always placed innermost), but that means dense access has become strided. 1078 /// For now, we reject vectorization of such cases. 1079 /// TODO: implement strided load/stores on dense arrays 1080 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1081 unsigned idx) { 1082 for (OpOperand *t : op.getInputAndOutputOperands()) { 1083 if (!getSparseTensorEncoding(t->get().getType())) { 1084 auto map = op.getTiedIndexingMap(t); 1085 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1086 if (map.getDimPosition(d) == idx && d != rank - 1) 1087 return false; 1088 } 1089 } 1090 } 1091 return true; 1092 } 1093 1094 /// Generates a for-loop on a single index. 
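/// Depending on the requested strategies and the loop position, this is either an scf.parallel loop or an scf.for loop, the latter possibly vectorized and/or carrying a scalarized reduction.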
1095 static Operation *genFor(Merger &merger, CodeGen &codegen, 1096 PatternRewriter &rewriter, linalg::GenericOp op, 1097 bool isOuter, bool isInner, unsigned idx, 1098 llvm::BitVector &indices) { 1099 unsigned fb = indices.find_first(); 1100 unsigned tensor = merger.tensor(fb); 1101 assert(idx == merger.index(fb)); 1102 auto iteratorTypes = op.iterator_types().getValue(); 1103 bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); 1104 bool isSparse = merger.isDim(fb, Dim::kSparse); 1105 bool isVector = isVectorFor(codegen, isInner, isSparse) && 1106 denseUnitStrides(merger, op, idx); 1107 bool isParallel = 1108 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1109 1110 // Prepare vector length. 1111 if (isVector) 1112 codegen.curVecLength = codegen.options.vectorLength; 1113 1114 // Loop bounds and increment. 1115 Location loc = op.getLoc(); 1116 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1117 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1118 Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength); 1119 1120 // Emit a parallel loop. 1121 if (isParallel) { 1122 assert(!isVector); 1123 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1124 if (isSparse) 1125 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1126 else 1127 codegen.loops[idx] = parOp.getInductionVars()[0]; 1128 rewriter.setInsertionPointToStart(parOp.getBody()); 1129 return parOp; 1130 } 1131 1132 // Emit a sequential loop, potentially with a scalarized reduction. 1133 bool scalarRed = isInner && codegen.redExp != -1u; 1134 SmallVector<Value, 4> operands; 1135 if (scalarRed) { 1136 Value load = genReductionStart(merger, codegen, rewriter, op); 1137 operands.push_back(load); 1138 } 1139 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1140 if (scalarRed) { 1141 codegen.redVal = merger.exp(codegen.redExp).val = 1142 forOp.getRegionIterArgs().front(); 1143 } 1144 // Assign induction variable to sparse or dense index. 1145 Value iv = forOp.getInductionVar(); 1146 if (isSparse) 1147 codegen.pidxs[tensor][idx] = iv; 1148 else 1149 codegen.loops[idx] = iv; 1150 rewriter.setInsertionPointToStart(forOp.getBody()); 1151 // Share vector iteration mask between all subsequent loads/stores. 1152 if (isVector) 1153 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1154 return forOp; 1155 } 1156 1157 /// Emit a while-loop for co-iteration over multiple indices. 1158 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1159 PatternRewriter &rewriter, linalg::GenericOp op, 1160 unsigned idx, bool needsUniv, 1161 llvm::BitVector &indices) { 1162 SmallVector<Type, 4> types; 1163 SmallVector<Value, 4> operands; 1164 // Construct the while-loop with a parameter for each index. 
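// Each co-iterated sparse tensor carries its position index as a loop argument; the universal dense index, when still needed, is carried last.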
1165 Type indexType = rewriter.getIndexType(); 1166 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1167 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1168 unsigned tensor = merger.tensor(b); 1169 assert(idx == merger.index(b)); 1170 types.push_back(indexType); 1171 assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() && 1172 "type mismatch for sparse index"); 1173 operands.push_back(codegen.pidxs[tensor][idx]); 1174 } 1175 } 1176 if (needsUniv) { 1177 types.push_back(indexType); 1178 assert(codegen.loops[idx].getType().isa<IndexType>() && 1179 "type mismatch for universal index"); 1180 operands.push_back(codegen.loops[idx]); 1181 } 1182 Location loc = op.getLoc(); 1183 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1184 Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1185 Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1186 1187 // Build the "before" region, which effectively consists 1188 // of a conjunction of "i < upper" tests on all induction. 1189 rewriter.setInsertionPointToStart(&whileOp.before().front()); 1190 Value cond; 1191 unsigned o = 0; 1192 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1193 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1194 unsigned tensor = merger.tensor(b); 1195 assert(idx == merger.index(b)); 1196 Value op1 = before->getArgument(o); 1197 Value op2 = codegen.highs[tensor][idx]; 1198 Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2); 1199 cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc; 1200 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1201 } 1202 } 1203 if (needsUniv) 1204 codegen.loops[idx] = after->getArgument(o++); 1205 assert(o == operands.size()); 1206 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1207 rewriter.setInsertionPointToStart(&whileOp.after().front()); 1208 return whileOp; 1209 } 1210 1211 /// Generates a for-loop or a while-loop, depending on whether it implements 1212 /// singleton iteration or co-iteration over the given conjunction. 1213 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1214 PatternRewriter &rewriter, linalg::GenericOp op, 1215 std::vector<unsigned> &topSort, unsigned at, 1216 bool needsUniv, llvm::BitVector &indices) { 1217 unsigned idx = topSort[at]; 1218 if (indices.count() == 1) { 1219 bool isOuter = at == 0; 1220 bool isInner = at == topSort.size() - 1; 1221 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1222 indices); 1223 } 1224 genReductionEnd(merger, codegen, rewriter, op); // cannot chain 1225 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1226 } 1227 1228 /// Generates the local variables for this loop, consisting of the sparse 1229 /// indices, restored universal dense index, and dense positions. 1230 static void genLocals(Merger &merger, CodeGen &codegen, 1231 PatternRewriter &rewriter, linalg::GenericOp op, 1232 std::vector<unsigned> &topSort, unsigned at, 1233 bool needsUniv, llvm::BitVector &locals) { 1234 Location loc = op.getLoc(); 1235 unsigned idx = topSort[at]; 1236 1237 // Initialize sparse indices. 
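// Load the index value at the current position of each sparse tensor; when no universal index is in play, the minimum of these loads becomes the loop index at this level.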
1238 Value min; 1239 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1240 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1241 unsigned tensor = merger.tensor(b); 1242 assert(idx == merger.index(b)); 1243 Value ptr = codegen.indices[tensor][idx]; 1244 Value s = codegen.pidxs[tensor][idx]; 1245 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1246 codegen.idxs[tensor][idx] = load; 1247 if (!needsUniv) { 1248 if (min) { 1249 Value cmp = 1250 rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min); 1251 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1252 } else { 1253 min = load; 1254 } 1255 } 1256 } 1257 } 1258 1259 // Merge dense universal index over minimum. 1260 if (min) { 1261 assert(!needsUniv); 1262 codegen.loops[idx] = min; 1263 } 1264 1265 // Initialize dense positions. Note that we generate dense indices of the 1266 // output tensor unconditionally, since they may not appear in the lattice, 1267 // but may be needed for linearized codegen. 1268 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1269 if ((locals[b] || merger.isOutTensor(b, idx)) && 1270 merger.isDim(b, Dim::kDense)) { 1271 unsigned tensor = merger.tensor(b); 1272 assert(idx == merger.index(b)); 1273 unsigned pat = at; 1274 for (; pat != 0; pat--) 1275 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1276 break; 1277 Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 1278 : codegen.pidxs[tensor][topSort[pat - 1]]; 1279 codegen.pidxs[tensor][idx] = genAddress( 1280 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1281 } 1282 } 1283 } 1284 1285 /// Generates the induction structure for a while-loop. 1286 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1287 PatternRewriter &rewriter, linalg::GenericOp op, 1288 unsigned idx, bool needsUniv, 1289 llvm::BitVector &induction, ResultRange results) { 1290 Location loc = op.getLoc(); 1291 unsigned o = 0; 1292 SmallVector<Value, 4> operands; 1293 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 1294 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1295 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1296 unsigned tensor = merger.tensor(b); 1297 assert(idx == merger.index(b)); 1298 Value op1 = codegen.idxs[tensor][idx]; 1299 Value op2 = codegen.loops[idx]; 1300 Value op3 = codegen.pidxs[tensor][idx]; 1301 Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1302 Value add = rewriter.create<AddIOp>(loc, op3, one); 1303 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1304 codegen.pidxs[tensor][idx] = results[o++]; 1305 } 1306 } 1307 if (needsUniv) { 1308 operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one)); 1309 codegen.loops[idx] = results[o++]; 1310 } 1311 assert(o == operands.size()); 1312 rewriter.create<scf::YieldOp>(loc, operands); 1313 } 1314 1315 /// Generates a single if-statement within a while-loop. 
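/// The condition tests, for each sparse tensor in the conjunction, whether its current index matches the loop index; dense tensors contribute a trivially true clause.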
1316 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1317 PatternRewriter &rewriter, linalg::GenericOp op, 1318 unsigned idx, llvm::BitVector &conditions) { 1319 Location loc = op.getLoc(); 1320 Value cond; 1321 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1322 if (conditions[b]) { 1323 unsigned tensor = merger.tensor(b); 1324 assert(idx == merger.index(b)); 1325 Value clause; 1326 if (merger.isDim(b, Dim::kSparse)) { 1327 Value op1 = codegen.idxs[tensor][idx]; 1328 Value op2 = codegen.loops[idx]; 1329 clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1330 } else { 1331 clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true 1332 } 1333 cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause; 1334 } 1335 } 1336 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true); 1337 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 1338 return ifOp; 1339 } 1340 1341 /// Recursively generates code while computing iteration lattices in order 1342 /// to manage the complexity of implementing co-iteration over unions 1343 /// and intersections of sparse iteration spaces. 1344 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1345 linalg::GenericOp op, std::vector<unsigned> &topSort, 1346 unsigned exp, unsigned at) { 1347 // At each leaf, assign remaining tensor (sub)expression to output tensor. 1348 if (at == topSort.size()) { 1349 OpOperand *lhs = op.getOutputOperand(0); 1350 Value rhs = genExp(merger, codegen, rewriter, op, exp); 1351 genTensorStore(merger, codegen, rewriter, op, lhs, rhs); 1352 return; 1353 } 1354 assert(codegen.curVecLength == 1); 1355 1356 // Construct iteration lattices for current loop index, with L0 at top. 1357 // Then emit initialization code for the loop sequence at this level. 1358 // We maintain the universal dense index if dense indices are still 1359 // in play for a non-singleton loop sequence. 1360 Location loc = op.getLoc(); 1361 unsigned idx = topSort[at]; 1362 unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx)); 1363 unsigned lsize = merger.set(lts).size(); 1364 assert(lsize != 0); 1365 unsigned l0 = merger.set(lts)[0]; 1366 unsigned ldx = at == 0 ? -1u : topSort[at - 1]; 1367 genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true); 1368 bool needsUniv = false; 1369 if (genInit(merger, codegen, rewriter, op, topSort, at, 1370 merger.lat(l0).bits)) { 1371 // Maintain the universal index only if it is actually 1372 // consumed by a subsequent lattice point. 1373 for (unsigned i = 1; i < lsize; i++) { 1374 unsigned li = merger.set(lts)[i]; 1375 if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) { 1376 needsUniv = true; 1377 break; 1378 } 1379 } 1380 } 1381 1382 // Emit a loop for every lattice point L0 >= Li. 1383 for (unsigned i = 0; i < lsize; i++) { 1384 unsigned li = merger.set(lts)[i]; 1385 1386 // Emit loop. 1387 codegen.curVecLength = 1; 1388 llvm::BitVector indices = merger.lat(li).simple; 1389 Operation *loop = 1390 genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices); 1391 genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, 1392 merger.lat(li).bits); 1393 1394 // Visit all lattice points with Li >= Lj to generate the 1395 // loop-body, possibly with if statements for coiteration.
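// Within a while-loop, each lattice point Lj covered by Li is guarded by its own if-statement; within a for-loop, the body is generated without a guard.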
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should eventually never need this).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for a generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
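    //
    // A hedged illustration (not taken from this file): a CSR-like operand
    // annotated with
    //   #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
    // would register its row dimension as Dim::kDense and its column
    // dimension as Dim::kSparse with the merger constructed below.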
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Rejects an inadmissible tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
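//
// A minimal usage sketch for these patterns, shown for illustration only and
// assuming an operation `op` and a configured `options` object are in scope
// (the actual pass wiring lives elsewhere):
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));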