1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering sparse tensor types to actual sparse code. 10 // 11 // The concept of letting a compiler generate sparse code automatically was 12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and 13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor 14 // Algebra Compiler (TACO). The implementation in this file closely follows 15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting 16 // rule is applied to each tensor expression in linalg (MLIR's tensor index 17 // notation) where the sparsity of tensors is indicated with annotation using 18 // a per-dimension specification of sparse/dense storage together with a 19 // specification of the order on the dimensions. Subsequently, a topologically 20 // sorted iteration graph, reflecting the required order on indices with respect 21 // to the dimensions of each tensor, is constructed to ensure that all tensors 22 // are visited in natural index order. Next, iteration lattices are constructed 23 // for the tensor expression for every index in topological order. Each 24 // iteration lattice point consists of a conjunction of tensor indices together 25 // with a tensor (sub)expression that needs to be evaluated for that 26 // conjunction. Within the lattice, iteration points are ordered according to 27 // the way indices are exhausted. As such these iteration lattices drive actual 28 // sparse code generation, which consists of a tedious but relatively 29 // straightforward one-to-one mapping from iteration lattices to combinations 30 // of for-loops, while-loops, and if-statements. 31 // 32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations. 33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php). 34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou, 35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler. 36 // Proceedings of the ACM on Programming Languages, October 2017. 37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation. 38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org). 39 // 40 // Implementation detail: We use llvm::SmallVector for vectors with 41 // variable lengths and std::vector for vectors with fixed lengths. 42 //===----------------------------------------------------------------------===// 43 44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 45 #include "mlir/Dialect/Linalg/Utils/Utils.h" 46 #include "mlir/Dialect/MemRef/IR/MemRef.h" 47 #include "mlir/Dialect/SCF/SCF.h" 48 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 49 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 50 #include "mlir/Dialect/StandardOps/IR/Ops.h" 51 #include "mlir/Dialect/Vector/VectorOps.h" 52 #include "mlir/IR/Matchers.h" 53 #include "mlir/IR/TensorEncoding.h" 54 #include "llvm/ADT/SmallBitVector.h" 55 56 using namespace mlir; 57 using namespace mlir::sparse_tensor; 58 59 namespace { 60 61 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; 62 enum class Dim { kSparse, kDense, kSingle, kUndef }; 63 64 /// Tensor expression. 
Represents an MLIR expression in tensor index notation.
65 /// For tensors, e0 denotes the tensor index. For invariants, the IR value is
66 /// stored directly. For binary operations, e0 and e1 denote the indices of
67 /// the child tensor expressions.
68 struct TensorExp {
69 TensorExp(Kind k, unsigned x, unsigned y, Value v)
70 : kind(k), e0(x), e1(y), val(v) {
71 assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
72 (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
73 (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
74 }
75 Kind kind;
76 /// Indices of children expression(s).
77 unsigned e0;
78 unsigned e1;
79 /// Direct link to IR for an invariant. During code generation, this
80 /// field is used to cache "hoisted" loop invariant tensor loads.
81 Value val;
82 };
83
84 /// Lattice point. Each lattice point consists of a conjunction of tensor
85 /// loop indices (encoded in a bitvector) and the index of the corresponding
86 /// tensor expression.
87 struct LatPoint {
88 LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
89 bits.set(b);
90 }
91 LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
92 /// Conjunction of tensor loop indices as bitvector. This represents
93 /// all indices involved in the tensor expression.
94 llvm::BitVector bits;
95 /// Simplified conjunction of tensor loop indices as bitvector. This
96 /// represents a simplified condition under which this tensor expression
97 /// must execute. Pre-computed during codegen to avoid repeated evaluation.
98 llvm::BitVector simple;
99 /// Index of the tensor expression.
100 unsigned exp;
101 };
102
103 /// A class to handle all iteration lattice operations. This class abstracts
104 /// away from some implementation details of storing iteration lattices and
105 /// tensor expressions. This allows for fine-tuning performance characteristics
106 /// independently from the basic algorithm if bottlenecks are identified.
107 class Merger {
108 public:
109 /// Constructs a merger for the given number of tensors and loops. The
110 /// user supplies the number of tensors involved in the kernel, with the
111 /// last tensor in this set denoting the output tensor. The merger adds an
112 /// additional synthetic tensor at the end of this set to represent all
113 /// invariant expressions in the kernel.
114 Merger(unsigned t, unsigned l)
115 : outTensor(t - 1), numTensors(t + 1), numLoops(l),
116 dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
117
118 /// Adds a tensor expression. Returns its index.
119 unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
120 unsigned e = tensorExps.size();
121 tensorExps.push_back(TensorExp(k, e0, e1, v));
122 return e;
123 }
124 unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
125
126 /// Adds an iteration lattice point. Returns its index.
127 unsigned addLat(unsigned t, unsigned i, unsigned e) {
128 assert(t < numTensors && i < numLoops);
129 unsigned p = latPoints.size();
130 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
131 return p;
132 }
133
134 /// Adds a new, initially empty, set. Returns its index.
135 unsigned addSet() { 136 unsigned s = latSets.size(); 137 latSets.emplace_back(SmallVector<unsigned, 16>()); 138 return s; 139 } 140 141 /// Computes a single conjunction of two lattice points by taking the "union" 142 /// of loop indices (effectively constructing a larger "intersection" of those 143 /// indices) with a newly constructed tensor (sub)expression of given kind. 144 /// Returns the index of the new lattice point. 145 unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 146 unsigned p = latPoints.size(); 147 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 148 nb |= latPoints[p1].bits; 149 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 150 latPoints.push_back(LatPoint(nb, e)); 151 return p; 152 } 153 154 /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of 155 /// cartesian product. Returns the index of the new set. 156 unsigned takeConj(Kind kind, unsigned s0, unsigned s1) { 157 unsigned s = addSet(); 158 for (unsigned p0 : latSets[s0]) 159 for (unsigned p1 : latSets[s1]) 160 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 161 return s; 162 } 163 164 /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). 165 /// Returns the index of the new set. 166 unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) { 167 unsigned s = takeConj(kind, s0, s1); 168 for (unsigned p : latSets[s0]) 169 latSets[s].push_back(p); 170 for (unsigned p : latSets[s1]) 171 latSets[s].push_back(p); 172 return s; 173 } 174 175 /// Optimizes the iteration lattice points in the given set. This 176 /// method should be called right before code generation to avoid 177 /// generating redundant loops and conditions. 178 unsigned optimizeSet(unsigned s0) { 179 unsigned s = addSet(); 180 assert(latSets[s0].size() != 0); 181 unsigned p0 = latSets[s0][0]; 182 for (unsigned p1 : latSets[s0]) { 183 bool add = true; 184 if (p0 != p1) { 185 // Is this a straightforward copy? 186 unsigned e = latPoints[p1].exp; 187 if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) 188 continue; 189 // Conjunction already covered? 190 for (unsigned p2 : latSets[s]) { 191 assert(!latGT(p1, p2)); // Lj => Li would be bad 192 if (onlyDenseDiff(p2, p1)) { 193 add = false; 194 break; 195 } 196 } 197 assert(!add || latGT(p0, p1)); 198 } 199 if (add) 200 latSets[s].push_back(p1); 201 } 202 for (unsigned p : latSets[s]) 203 latPoints[p].simple = simplifyCond(s, p); 204 return s; 205 } 206 207 /// Simplifies the conditions in a conjunction of a given lattice point 208 /// within the given set using just two basic rules: 209 /// (1) multiple dense conditions are reduced to single dense, and 210 /// (2) a *singleton* sparse/dense is reduced to sparse/random access. 211 llvm::BitVector simplifyCond(unsigned s, unsigned p0) { 212 // First determine if this lattice point is a *singleton*, i.e., 213 // the last point in a lattice, no other is less than this one. 214 bool isSingleton = true; 215 for (unsigned p1 : latSets[s]) { 216 if (p0 != p1 && latGT(p0, p1)) { 217 isSingleton = false; 218 break; 219 } 220 } 221 // Now apply the two basic rules. 222 llvm::BitVector simple = latPoints[p0].bits; 223 bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); 224 for (unsigned b = 0, be = simple.size(); b < be; b++) { 225 if (simple[b] && !isDim(b, Dim::kSparse)) { 226 if (reset) 227 simple.reset(b); 228 reset = true; 229 } 230 } 231 return simple; 232 } 233 234 /// Returns true if Li > Lj. 
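/// Here "greater than" means that the bit set of Li strictly contains the
/// bit set of Lj. For example, a lattice point that iterates over the
/// conjunction a(i) /\ b(i) is greater than a point that iterates over
/// just a(i), since the former constrains the iteration space further.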
235 bool latGT(unsigned i, unsigned j) const { 236 const llvm::BitVector &bitsi = latPoints[i].bits; 237 const llvm::BitVector &bitsj = latPoints[j].bits; 238 assert(bitsi.size() == bitsj.size()); 239 if (bitsi.count() > bitsj.count()) { 240 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 241 if (bitsj[b] && !bitsi[b]) 242 return false; 243 return true; 244 } 245 return false; 246 } 247 248 /// Returns true if Li and Lj only differ in dense. 249 bool onlyDenseDiff(unsigned i, unsigned j) { 250 llvm::BitVector tmp = latPoints[j].bits; 251 tmp ^= latPoints[i].bits; 252 return !hasAnyDimOf(tmp, Dim::kSparse); 253 } 254 255 /// Bit translation. 256 unsigned tensor(unsigned b) const { return b % numTensors; } 257 unsigned index(unsigned b) const { return b / numTensors; } 258 259 /// Returns true if bit corresponds to queried dim. 260 bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } 261 262 /// Returns true if bit corresponds to index of output tensor. 263 bool isOutTensor(unsigned b, unsigned i) const { 264 return tensor(b) == outTensor && index(b) == i; 265 } 266 267 /// Returns true if tensor access at given index has queried dim. 268 bool isDim(unsigned t, unsigned i, Dim d) const { 269 assert(t < numTensors && i < numLoops); 270 return dims[t][i] == d; 271 } 272 273 /// Returns true if any set bit corresponds to queried dim. 274 bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 275 for (unsigned b = 0, be = bits.size(); b < be; b++) 276 if (bits[b] && isDim(b, d)) 277 return true; 278 return false; 279 } 280 281 /// Setter 282 void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } 283 284 /// Getters. 285 TensorExp &exp(unsigned e) { return tensorExps[e]; } 286 LatPoint &lat(unsigned l) { return latPoints[l]; } 287 SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; } 288 289 private: 290 const unsigned outTensor; 291 const unsigned numTensors; 292 const unsigned numLoops; 293 294 std::vector<std::vector<Dim>> dims; 295 llvm::SmallVector<TensorExp, 32> tensorExps; 296 llvm::SmallVector<LatPoint, 16> latPoints; 297 llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets; 298 }; 299 300 // Code generation. 301 struct CodeGen { 302 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) 303 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 304 pointers(numTensors, std::vector<Value>(numLoops)), 305 indices(numTensors, std::vector<Value>(numLoops)), 306 highs(numTensors, std::vector<Value>(numLoops)), 307 pidxs(numTensors, std::vector<Value>(numLoops)), 308 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 309 curVecLength(1), curVecMask() {} 310 /// Sparsification options. 311 SparsificationOptions options; 312 /// Universal dense indices and upper bounds (by index). The loops array 313 /// is updated with the value of the universal dense index in the current 314 /// loop. The sizes array is set once with the inferred dimension sizes. 315 std::vector<Value> loops; 316 std::vector<Value> sizes; 317 /// Buffers for storing dense and sparse numerical values (by tensor). 318 /// This array is set once during bufferization of all tensors. 319 std::vector<Value> buffers; 320 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 321 /// This array is set once during bufferization of all sparse tensors. 322 std::vector<std::vector<Value>> pointers; 323 std::vector<std::vector<Value>> indices; 324 /// Sparse iteration information (by tensor and index). 
These arrays 325 /// are updated to remain current within the current loop. 326 std::vector<std::vector<Value>> highs; 327 std::vector<std::vector<Value>> pidxs; 328 std::vector<std::vector<Value>> idxs; 329 /// Current reduction, updated during code generation. When indices of a 330 /// reduction are exhausted, all inner loops can "scalarize" the reduction. 331 // TODO: currently only done for (a chain of) innermost for-loops, where it 332 // is most effective; we could generalize to more outer and while-loops. 333 unsigned redExp; 334 Value redVal; 335 // Current vector length and mask. 336 unsigned curVecLength; 337 Value curVecMask; 338 }; 339 340 } // namespace 341 342 // Helper method to apply dimension ordering permutation. 343 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) { 344 if (enc) { 345 auto order = enc.getDimOrdering(); 346 if (order) { 347 assert(order.isPermutation()); 348 return order.getDimPosition(d); 349 } 350 } 351 return d; 352 } 353 354 // Helper method to translate dim level type to internal representation. 355 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) { 356 if (enc) { 357 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 358 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 359 return Dim::kSparse; 360 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 361 return Dim::kSingle; 362 } 363 return Dim::kDense; 364 } 365 366 /// Helper method to inspect sparse encodings in the tensor types. 367 /// Fills the per-dimension sparsity information for all tensors. 368 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 369 bool annotated = false; 370 OpOperand *lhs = op.getOutputOperand(0); 371 for (OpOperand *t : op.getInputAndOutputOperands()) { 372 auto map = op.getTiedIndexingMap(t); 373 if (!map.isProjectedPermutation()) 374 return false; 375 auto enc = getSparseTensorEncoding(t->get().getType()); 376 if (enc) 377 annotated = true; 378 assert(map.getNumResults() == op.getRank(t)); 379 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 380 unsigned idx = map.getDimPosition(perm(enc, d)); 381 Dim dim = toDim(enc, d); 382 merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); 383 // Accept only all-dense annotated "sparse" output. 384 // TODO: support truly sparse outputs too 385 if (t == lhs && dim != Dim::kDense) 386 return false; 387 } 388 } 389 return annotated; 390 } 391 392 /// A DFS helper to compute a topological sort. Note that recursion is 393 /// bounded by the number of implicit loops, which is always small. 394 /// Returns false when a cycle is detected. 395 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 396 std::vector<unsigned> &topSort, 397 std::vector<std::vector<bool>> &adjM) { 398 if (visit[i] != 0) 399 return visit[i] != 1; // 1 denotes cycle! 400 visit[i] = 1; 401 for (unsigned j = 0, e = visit.size(); j < e; j++) 402 if (adjM[i][j]) 403 if (!topSortDFS(j, visit, topSort, adjM)) 404 return false; 405 visit[i] = 2; 406 topSort.push_back(i); 407 return true; 408 } 409 410 /// Computes a topologically sorted iteration graph for the linalg operation. 411 /// Ensures all tensors are visited in natural index order. This is essential 412 /// for sparse storage formats since these only support access along fixed 413 /// dimensions. Even for dense storage formats, however, the natural index 414 /// order yields innermost unit-stride access with better spatial locality. 
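/// For example, an expression x(i,j) = a(i,j) + b(j,i) yields the
/// constraints i < j (from a) and j < i (from b), i.e. a cyclic iteration
/// graph; the caller may then retry with sparseOnly=true so that only the
/// constraints of annotated (sparse) tensors are kept.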
415 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 416 std::vector<unsigned> &topSort, 417 bool sparseOnly) { 418 // Set up an n x n from/to adjacency matrix of the iteration graph 419 // for the implicit loop indices i_0 .. i_n-1. 420 unsigned n = op.getNumLoops(); 421 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 422 423 // Iterate over the indexing maps of every tensor in the tensor expression. 424 for (OpOperand *t : op.getInputAndOutputOperands()) { 425 auto map = op.getTiedIndexingMap(t); 426 auto enc = getSparseTensorEncoding(t->get().getType()); 427 assert(map.getNumDims() == n); 428 // Skip dense tensor constraints when sparse only is requested. 429 if (sparseOnly && !enc) 430 continue; 431 // Each tensor expression and optional dimension ordering (row-major 432 // by default) puts an ordering constraint on the loop indices. For 433 // example, the tensor expresion A_ijk forces the ordering i < j < k 434 // on the loop indices if no explicit dimension ordering is given. 435 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 436 unsigned f = map.getDimPosition(perm(enc, d - 1)); 437 unsigned t = map.getDimPosition(perm(enc, d)); 438 adjM[f][t] = true; 439 } 440 } 441 442 // Topologically sort the iteration graph to determine loop order. 443 // Report failure for a cyclic iteration graph. 444 topSort.clear(); 445 topSort.reserve(n); 446 std::vector<unsigned> visit(n, 0); 447 for (unsigned i = 0; i < n; i++) 448 if (visit[i] == 0) 449 if (!topSortDFS(i, visit, topSort, adjM)) 450 return false; // cycle! 451 std::reverse(std::begin(topSort), std::end(topSort)); 452 return true; 453 } 454 455 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 456 /// This simplifies constructing (sub)expressions during iteration lattice 457 /// building (compared to using the SSA representation everywhere). 458 static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op, 459 Value val) { 460 if (auto arg = val.dyn_cast<BlockArgument>()) { 461 unsigned argN = arg.getArgNumber(); 462 // Any argument of the generic op that is not marked as a scalar 463 // argument is considered a tensor, indexed by the implicit loop 464 // bounds. This includes rank-0 tensor arguments. 465 if (arg.getOwner()->getParentOp() == op) { 466 OpOperand *t = op.getInputAndOutputOperands()[argN]; 467 if (!op.isScalar(t)) 468 return merger.addExp(Kind::kTensor, argN); 469 val = t->get(); // get scalar value 470 } 471 // Any other argument (marked as scalar argument for the generic op 472 // or belonging to an enveloping op) is considered invariant. 473 return merger.addExp(Kind::kInvariant, val); 474 } 475 Operation *def = val.getDefiningOp(); 476 if (def->getBlock() != &op.region().front()) { 477 // Something defined outside is invariant. 478 return merger.addExp(Kind::kInvariant, val); 479 } else if (def->getNumOperands() == 2) { 480 // Construct binary operations if subexpressions could be built. 
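// For example, a generic body computing
//   %0 = mulf %a, %b : f64
//   %1 = addf %0, %c : f64
// is built bottom-up into kMulF over the two tensor arguments and then
// kAddF over that product and the third tensor argument.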
481 auto x = buildTensorExp(merger, op, def->getOperand(0)); 482 auto y = buildTensorExp(merger, op, def->getOperand(1)); 483 if (x.hasValue() && y.hasValue()) { 484 unsigned e0 = x.getValue(); 485 unsigned e1 = y.getValue(); 486 if (isa<MulFOp>(def)) 487 return merger.addExp(Kind::kMulF, e0, e1); 488 if (isa<MulIOp>(def)) 489 return merger.addExp(Kind::kMulI, e0, e1); 490 if (isa<AddFOp>(def)) 491 return merger.addExp(Kind::kAddF, e0, e1); 492 if (isa<AddIOp>(def)) 493 return merger.addExp(Kind::kAddI, e0, e1); 494 } 495 } 496 // Cannot build (yet). 497 return None; 498 } 499 500 /// Builds the iteration lattices in a bottom-up traversal given the remaining 501 /// tensor (sub)expression and the next loop index in the iteration graph. 502 static unsigned buildLattices(Merger &merger, linalg::GenericOp op, 503 unsigned exp, unsigned idx) { 504 Kind kind = merger.exp(exp).kind; 505 if (kind == Kind::kTensor || kind == Kind::kInvariant) { 506 // Either the index is really used in the tensor expression, or it is 507 // set to the undefined index in that dimension. An invariant expression 508 // is set to a synthetic tensor with undefined indices only. 509 unsigned s = merger.addSet(); 510 unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 511 : op.getNumInputsAndOutputs(); 512 merger.set(s).push_back(merger.addLat(t, idx, exp)); 513 return s; 514 } 515 unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); 516 unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx); 517 switch (kind) { 518 case Kind::kTensor: 519 case Kind::kInvariant: 520 llvm_unreachable("handled above"); 521 case Kind::kMulF: 522 case Kind::kMulI: 523 return merger.takeConj(kind, s0, s1); 524 case Kind::kAddF: 525 case Kind::kAddI: 526 return merger.takeDisj(kind, s0, s1); 527 } 528 llvm_unreachable("unexpected expression kind"); 529 } 530 531 /// Maps sparse integer option to actual integral storage type. 532 static Type genIntType(PatternRewriter &rewriter, unsigned width) { 533 if (width == 0) 534 return rewriter.getIndexType(); 535 return rewriter.getIntegerType(width); 536 } 537 538 /// Detects in-place annotation on tensor argument. 539 static bool getInPlace(Value val) { 540 if (auto arg = val.dyn_cast<BlockArgument>()) 541 if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) 542 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 543 arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) 544 return attr.getValue(); 545 return false; 546 } 547 548 /// Generates buffer for the output tensor. 549 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, 550 linalg::GenericOp op, MemRefType denseTp, 551 ArrayRef<Value> args) { 552 Location loc = op.getLoc(); 553 Value tensor = op.getOutputOperand(0)->get(); 554 // The output tensor simply could materialize from the buffer that will 555 // be generated for the tensor present in the outs() clause. This has 556 // the major advantage that the sparse kernel only updates the nonzero 557 // positions for the output tensor. 558 if (getInPlace(tensor)) 559 return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 560 // By default, a new buffer is allocated which is initialized to the 561 // tensor defined in the outs() clause. This is always correct but 562 // introduces a dense initialization component that may negatively 563 // impact the running complexity of the sparse kernel. 
564 Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 565 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 566 rewriter.create<linalg::CopyOp>(loc, init, alloc); 567 return alloc; 568 } 569 570 /// Local bufferization of all dense and sparse data structures. 571 /// This code enables testing the first prototype sparse compiler. 572 // TODO: replace this with a proliferated bufferization strategy 573 static bool genBuffers(Merger &merger, CodeGen &codegen, 574 PatternRewriter &rewriter, linalg::GenericOp op) { 575 Location loc = op.getLoc(); 576 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 577 // For every tensor, find lower and upper bound on dimensions, set the 578 // same bounds on loop indices, and obtain dense or sparse buffer(s). 579 SmallVector<Value, 4> args; 580 for (OpOperand *t : op.getInputAndOutputOperands()) { 581 unsigned tensor = t->getOperandNumber(); 582 auto shape = op.getShape(t); 583 auto map = op.getTiedIndexingMap(t); 584 auto enc = getSparseTensorEncoding(t->get().getType()); 585 // Scan all dimensions of current tensor. 586 args.clear(); 587 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 588 unsigned idx = map.getDimPosition(perm(enc, d)); 589 // Handle sparse storage schemes. 590 if (merger.isDim(tensor, idx, Dim::kSparse)) { 591 auto dynShape = {ShapedType::kDynamicSize}; 592 auto ptrTp = MemRefType::get( 593 dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 594 auto indTp = MemRefType::get( 595 dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 596 Value dim = rewriter.create<ConstantIndexOp>(loc, d); 597 // Generate sparse primitives to obtains pointer and indices. 598 codegen.pointers[tensor][idx] = 599 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 600 codegen.indices[tensor][idx] = 601 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 602 } 603 // Find lower and upper bound in current dimension. 604 Value up; 605 if (shape[d] == MemRefType::kDynamicSize) { 606 up = rewriter.create<memref::DimOp>(loc, t->get(), d); 607 args.push_back(up); 608 } else { 609 up = rewriter.create<ConstantIndexOp>(loc, shape[d]); 610 } 611 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 612 } 613 // Perform the required bufferization. Dense inputs materialize 614 // from the input tensors. Dense outputs need special handling. 615 // Sparse inputs use sparse primitives to obtain the values. 616 // We also accept in-place all-dense annotated "sparse" outputs. 617 Type elementType = getElementTypeOrSelf(t->get().getType()); 618 if (!enc) { 619 // Non-annotated dense tensors. 620 auto denseTp = MemRefType::get(shape, elementType); 621 if (tensor < op.getNumInputs()) 622 codegen.buffers[tensor] = 623 rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get()); 624 else 625 codegen.buffers[tensor] = 626 genOutputBuffer(codegen, rewriter, op, denseTp, args); 627 } else { 628 // Annotated sparse tensors. 629 if (tensor == op.getNumInputs() && !getInPlace(t->get())) 630 return false; // reject output if not in-place 631 auto dynShape = {ShapedType::kDynamicSize}; 632 auto sparseTp = MemRefType::get(dynShape, elementType); 633 codegen.buffers[tensor] = 634 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 635 } 636 } 637 return true; 638 } 639 640 /// Constructs vector type. 641 static VectorType vectorType(CodeGen &codegen, Type etp) { 642 return VectorType::get(codegen.curVecLength, etp); 643 } 644 645 /// Constructs vector type from pointer. 
646 static VectorType vectorType(CodeGen &codegen, Value ptr) { 647 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 648 } 649 650 /// Constructs vector iteration mask. 651 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 652 Value iv, Value lo, Value hi, Value step) { 653 Location loc = iv.getLoc(); 654 VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1)); 655 // Special case if the vector length evenly divides the trip count (for 656 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 657 // so that all subsequent masked memory operations are immediately folded 658 // into unconditional memory operations. 659 IntegerAttr loInt, hiInt, stepInt; 660 if (matchPattern(lo, m_Constant(&loInt)) && 661 matchPattern(hi, m_Constant(&hiInt)) && 662 matchPattern(step, m_Constant(&stepInt))) { 663 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 664 return rewriter.create<vector::BroadcastOp>( 665 loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1)); 666 } 667 // Otherwise, generate a vector mask that avoids overrunning the upperbound 668 // during vector execution. Here we rely on subsequent loop optimizations to 669 // avoid executing the mask in all iterations, for example, by splitting the 670 // loop into an unconditional vector loop and a scalar cleanup loop. 671 Value end = rewriter.create<SubIOp>(loc, hi, iv); 672 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 673 } 674 675 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 676 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 677 Value ptr, ArrayRef<Value> args) { 678 Location loc = ptr.getLoc(); 679 VectorType vtp = vectorType(codegen, ptr); 680 Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 681 if (args.back().getType().isa<VectorType>()) { 682 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 683 Value indexVec = args.back(); 684 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 685 return rewriter.create<vector::GatherOp>( 686 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 687 } 688 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 689 codegen.curVecMask, pass); 690 } 691 692 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 693 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 694 Value rhs, Value ptr, ArrayRef<Value> args) { 695 Location loc = ptr.getLoc(); 696 if (args.back().getType().isa<VectorType>()) { 697 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 698 Value indexVec = args.back(); 699 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 700 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 701 codegen.curVecMask, rhs); 702 return; 703 } 704 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 705 rhs); 706 } 707 708 /// Generates a vectorized invariant. Here we rely on subsequent loop 709 /// optimizations to hoist the invariant broadcast out of the vector loop. 710 static Value genVectorInvariantValue(CodeGen &codegen, 711 PatternRewriter &rewriter, Value val) { 712 VectorType vtp = vectorType(codegen, val.getType()); 713 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 714 } 715 716 /// Generates a load on a dense or sparse tensor. 
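/// For a dense operand this is a regular element load at the universal loop
/// indices (roughly buffer[i][j]); for an annotated sparse operand the load
/// is indexed by the position index of the innermost dimension (roughly
/// values[pidx]). With a vector length > 1, a masked vector load is emitted
/// instead.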
717 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 718 PatternRewriter &rewriter, linalg::GenericOp op, 719 unsigned exp) { 720 // Test if the load was hoisted to a higher loop nest. 721 Value val = merger.exp(exp).val; 722 if (val) { 723 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 724 return genVectorInvariantValue(codegen, rewriter, val); 725 return val; 726 } 727 // Actual load. 728 SmallVector<Value, 4> args; 729 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 730 unsigned tensor = t->getOperandNumber(); 731 auto map = op.getTiedIndexingMap(t); 732 auto enc = getSparseTensorEncoding(t->get().getType()); 733 unsigned rank = map.getNumResults(); 734 if (enc) { 735 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 736 assert(codegen.pidxs[tensor][idx] != nullptr); 737 args.push_back(codegen.pidxs[tensor][idx]); // position index 738 } else { 739 for (unsigned d = 0; d < rank; d++) { 740 unsigned idx = map.getDimPosition(d); 741 args.push_back(codegen.loops[idx]); // universal dense index 742 } 743 } 744 Location loc = op.getLoc(); 745 Value ptr = codegen.buffers[tensor]; 746 if (codegen.curVecLength > 1) 747 return genVectorLoad(codegen, rewriter, ptr, args); 748 return rewriter.create<memref::LoadOp>(loc, ptr, args); 749 } 750 751 /// Generates a store on a dense or sparse tensor. 752 static void genTensorStore(Merger &merger, CodeGen &codegen, 753 PatternRewriter &rewriter, linalg::GenericOp op, 754 OpOperand *t, Value rhs) { 755 Location loc = op.getLoc(); 756 // Test if this is a scalarized reduction. 757 OpOperand *lhs = op.getOutputOperand(0); 758 if (lhs == t && codegen.redVal) { 759 if (codegen.curVecLength > 1) 760 rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs, 761 codegen.redVal); 762 codegen.redVal = rhs; 763 return; 764 } 765 // Actual store. 766 SmallVector<Value, 4> args; 767 unsigned tensor = t->getOperandNumber(); 768 auto map = op.getTiedIndexingMap(t); 769 auto enc = getSparseTensorEncoding(t->get().getType()); 770 unsigned rank = map.getNumResults(); 771 if (enc) { 772 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 773 assert(codegen.pidxs[tensor][idx] != nullptr); 774 args.push_back(codegen.pidxs[tensor][idx]); // position index 775 } else { 776 for (unsigned d = 0; d < rank; d++) { 777 unsigned idx = map.getDimPosition(d); 778 args.push_back(codegen.loops[idx]); // universal dense index 779 } 780 } 781 Value ptr = codegen.buffers[tensor]; 782 if (codegen.curVecLength > 1) 783 genVectorStore(codegen, rewriter, rhs, ptr, args); 784 else 785 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 786 } 787 788 /// Generates a pointer/index load from the sparse storage scheme. Narrower 789 /// data types need to be zero extended before casting the value into the 790 /// index type used for looping and indexing. 791 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, 792 Value ptr, Value s) { 793 // See https://llvm.org/docs/GetElementPtr.html for some background on 794 // the complications described below. 795 if (codegen.curVecLength > 1) { 796 // Since the index vector is used in a subsequent gather/scatter operations, 797 // which effectively defines an unsigned pointer + signed index, we must 798 // zero extend the vector to an index width. For 8-bit and 16-bit values, 799 // an 32-bit index width suffices. 
For 32-bit values, zero extending the 800 // elements into 64-bit loses some performance since the 32-bit indexed 801 // gather/scatter is more efficient than the 64-bit index variant (if the 802 // negative 32-bit index space is unused, the enableSIMDIndex32 flag can 803 // preserve this performance). For 64-bit values, there is no good way 804 // to state that the indices are unsigned, with creates the potential of 805 // incorrect address calculations in the unlikely case we need such 806 // extremely large offsets. 807 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 808 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 809 if (!etp.isa<IndexType>()) { 810 if (etp.getIntOrFloatBitWidth() < 32) 811 vload = rewriter.create<ZeroExtendIOp>( 812 loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); 813 else if (etp.getIntOrFloatBitWidth() < 64 && 814 !codegen.options.enableSIMDIndex32) 815 vload = rewriter.create<ZeroExtendIOp>( 816 loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); 817 } 818 return vload; 819 } 820 // For the scalar case, we simply zero extend narrower indices into 64-bit 821 // values before casting to index without a performance penalty. Here too, 822 // however, indices that already are 64-bit, in theory, cannot express the 823 // full range as explained above. 824 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 825 if (!load.getType().isa<IndexType>()) { 826 if (load.getType().getIntOrFloatBitWidth() < 64) 827 load = rewriter.create<ZeroExtendIOp>(loc, load, 828 rewriter.getIntegerType(64)); 829 load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType()); 830 } 831 return load; 832 } 833 834 /// Generates an invariant value. 835 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 836 PatternRewriter &rewriter, unsigned exp) { 837 Value val = merger.exp(exp).val; 838 if (codegen.curVecLength > 1) 839 return genVectorInvariantValue(codegen, rewriter, val); 840 return val; 841 } 842 843 /// Generates an address computation "sz * p + i". 844 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 845 Location loc, Value size, Value p, Value i) { 846 Value mul = rewriter.create<MulIOp>(loc, size, p); 847 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 848 Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType()); 849 mul = genVectorInvariantValue(codegen, rewriter, inv); 850 } 851 return rewriter.create<AddIOp>(loc, mul, i); 852 } 853 854 /// Generates start of a reduction. 855 static Value genReductionStart(Merger &merger, CodeGen &codegen, 856 PatternRewriter &rewriter, 857 linalg::GenericOp op) { 858 if (codegen.redVal) 859 return codegen.redVal; // chained with previous for-loop 860 if (codegen.curVecLength > 1) { 861 // TODO: assumes + reductions for now 862 VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]); 863 return rewriter.create<ConstantOp>(op.getLoc(), vtp, 864 rewriter.getZeroAttr(vtp)); 865 } 866 return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 867 } 868 869 /// Generates end of a reduction. 
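/// For a scalarized sum, e.g. x += a(i) * b(i), the running value is carried
/// in codegen.redVal by the chain of for-loops; this method ends such a chain
/// by (horizontally reducing a vector value if needed and) storing the result
/// back into the output tensor.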
870 static void genReductionEnd(Merger &merger, CodeGen &codegen, 871 PatternRewriter &rewriter, linalg::GenericOp op) { 872 Value red = codegen.redVal; 873 if (!red) 874 return; 875 assert(codegen.curVecLength == 1); 876 codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain 877 OpOperand *lhs = op.getOutputOperand(0); 878 if (auto vtp = red.getType().dyn_cast<VectorType>()) { 879 // TODO: assumes + reductions for now 880 StringAttr kind = rewriter.getStringAttr("add"); 881 Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 882 // Integer reductions don't accept an accumulator. 883 if (vtp.getElementType().isa<IntegerType>()) { 884 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 885 kind, red, ValueRange{}); 886 red = rewriter.create<AddIOp>(op.getLoc(), red, ld); 887 } else { 888 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 889 kind, red, ld); 890 } 891 } 892 genTensorStore(merger, codegen, rewriter, op, lhs, red); 893 } 894 895 /// Recursively generates tensor expression. 896 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 897 linalg::GenericOp op, unsigned exp) { 898 if (merger.exp(exp).kind == Kind::kTensor) 899 return genTensorLoad(merger, codegen, rewriter, op, exp); 900 else if (merger.exp(exp).kind == Kind::kInvariant) 901 return genInvariantValue(merger, codegen, rewriter, exp); 902 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); 903 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); 904 switch (merger.exp(exp).kind) { 905 case Kind::kTensor: 906 case Kind::kInvariant: 907 llvm_unreachable("handled above"); 908 case Kind::kMulF: 909 return rewriter.create<MulFOp>(op.getLoc(), v0, v1); 910 case Kind::kMulI: 911 return rewriter.create<MulIOp>(op.getLoc(), v0, v1); 912 case Kind::kAddF: 913 return rewriter.create<AddFOp>(op.getLoc(), v0, v1); 914 case Kind::kAddI: 915 return rewriter.create<AddIOp>(op.getLoc(), v0, v1); 916 } 917 llvm_unreachable("unexpected expression kind"); 918 } 919 920 /// Hoists loop invariant tensor loads for which indices have been exhausted. 921 static void genInvariants(Merger &merger, CodeGen &codegen, 922 PatternRewriter &rewriter, linalg::GenericOp op, 923 unsigned exp, unsigned ldx, bool hoist) { 924 if (merger.exp(exp).kind == Kind::kTensor) { 925 // Inspect tensor indices. 926 bool atLevel = ldx == -1u; 927 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 928 auto map = op.getTiedIndexingMap(t); 929 auto enc = getSparseTensorEncoding(t->get().getType()); 930 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 931 unsigned idx = map.getDimPosition(perm(enc, d)); 932 if (!codegen.loops[idx]) 933 return; // still in play 934 else if (idx == ldx) 935 atLevel = true; 936 } 937 // All exhausted at this level (atLevel denotes exactly at this level). 938 OpOperand *lhs = op.getOutputOperand(0); 939 if (lhs == t) { 940 codegen.redExp = hoist ? exp : -1u; 941 } else if (atLevel) { 942 merger.exp(exp).val = 943 hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 944 } 945 } else if (merger.exp(exp).kind != Kind::kInvariant) { 946 // Traverse into the binary operations. Note that we only hoist 947 // tensor loads, since subsequent MLIR/LLVM passes know how to 948 // deal with all other kinds of derived loop invariants. 
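// For example, in x(i,j) = a(i,j) * b(i), recursing into the multiplication
// eventually hoists the load of b(i) out of the inner j-loop (and caches it
// in the val field of its tensor expression), since its only index i is
// already exhausted at the outer level.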
949 unsigned e0 = merger.exp(exp).e0; 950 unsigned e1 = merger.exp(exp).e1; 951 genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); 952 genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); 953 } 954 } 955 956 /// Generates initialization code for the subsequent loop sequence at 957 /// current index level. Returns true if the loop sequence needs to 958 /// maintain the universal index. 959 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 960 linalg::GenericOp op, std::vector<unsigned> &topSort, 961 unsigned at, llvm::BitVector &inits) { 962 bool needsUniv = false; 963 Location loc = op.getLoc(); 964 unsigned idx = topSort[at]; 965 966 // Initialize sparse positions. 967 for (unsigned b = 0, be = inits.size(); b < be; b++) { 968 if (inits[b]) { 969 unsigned tensor = merger.tensor(b); 970 assert(idx == merger.index(b)); 971 if (merger.isDim(b, Dim::kSparse)) { 972 // Initialize sparse index. 973 unsigned pat = at; 974 for (; pat != 0; pat--) { 975 if (codegen.pidxs[tensor][topSort[pat - 1]]) 976 break; 977 } 978 Value ptr = codegen.pointers[tensor][idx]; 979 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 980 Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 981 : codegen.pidxs[tensor][topSort[pat - 1]]; 982 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 983 Value p1 = rewriter.create<AddIOp>(loc, p0, one); 984 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 985 } else { 986 // Dense index still in play. 987 needsUniv = true; 988 } 989 } 990 } 991 992 // Initialize the universal dense index. 993 codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0); 994 return needsUniv; 995 } 996 997 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 998 /// operation is a candidate. Whether it is actually converted to SIMD code 999 /// depends on the requested strategy. 1000 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 1001 switch (codegen.options.vectorizationStrategy) { 1002 case SparseVectorizationStrategy::kNone: 1003 return false; 1004 case SparseVectorizationStrategy::kDenseInnerLoop: 1005 return isInner && !isSparse; 1006 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1007 return isInner; 1008 } 1009 llvm_unreachable("unexpected vectorization strategy"); 1010 } 1011 1012 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1013 /// that is marked "parallel" is a candidate. Whether it is actually converted 1014 /// to a parallel operation depends on the requested strategy. 1015 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1016 bool isSparse, bool isVector) { 1017 switch (codegen.options.parallelizationStrategy) { 1018 case SparseParallelizationStrategy::kNone: 1019 return false; 1020 case SparseParallelizationStrategy::kDenseOuterLoop: 1021 return isOuter && !isSparse && !isReduction && !isVector; 1022 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1023 return isOuter && !isReduction && !isVector; 1024 case SparseParallelizationStrategy::kDenseAnyLoop: 1025 return !isSparse && !isReduction && !isVector; 1026 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1027 return !isReduction && !isVector; 1028 } 1029 llvm_unreachable("unexpected parallelization strategy"); 1030 } 1031 1032 /// Checks unit strides for dense tensors. 
The iteration graph may have ignored 1033 /// dense access patterns in order to avoid cycles (sparse access patterns are 1034 /// always placed innermost), but that means dense access has become strided. 1035 /// For now, we reject vectorization of such cases. 1036 /// TODO: implement strided load/stores on dense arrays 1037 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1038 unsigned idx) { 1039 for (OpOperand *t : op.getInputAndOutputOperands()) { 1040 if (!getSparseTensorEncoding(t->get().getType())) { 1041 auto map = op.getTiedIndexingMap(t); 1042 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1043 if (map.getDimPosition(d) == idx && d != rank - 1) 1044 return false; 1045 } 1046 } 1047 } 1048 return true; 1049 } 1050 1051 /// Generates a for-loop on a single index. 1052 static Operation *genFor(Merger &merger, CodeGen &codegen, 1053 PatternRewriter &rewriter, linalg::GenericOp op, 1054 bool isOuter, bool isInner, unsigned idx, 1055 llvm::BitVector &indices) { 1056 unsigned fb = indices.find_first(); 1057 unsigned tensor = merger.tensor(fb); 1058 assert(idx == merger.index(fb)); 1059 auto iteratorTypes = op.iterator_types().getValue(); 1060 bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); 1061 bool isSparse = merger.isDim(fb, Dim::kSparse); 1062 bool isVector = isVectorFor(codegen, isInner, isSparse) && 1063 denseUnitStrides(merger, op, idx); 1064 bool isParallel = 1065 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1066 1067 // Prepare vector length. 1068 if (isVector) 1069 codegen.curVecLength = codegen.options.vectorLength; 1070 1071 // Loop bounds and increment. 1072 Location loc = op.getLoc(); 1073 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1074 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1075 Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength); 1076 1077 // Emit a parallel loop. 1078 if (isParallel) { 1079 assert(!isVector); 1080 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1081 if (isSparse) 1082 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1083 else 1084 codegen.loops[idx] = parOp.getInductionVars()[0]; 1085 rewriter.setInsertionPointToStart(parOp.getBody()); 1086 return parOp; 1087 } 1088 1089 // Emit a sequential loop, potentially with a scalarized reduction. 1090 bool scalarRed = isInner && codegen.redExp != -1u; 1091 SmallVector<Value, 4> operands; 1092 if (scalarRed) { 1093 Value load = genReductionStart(merger, codegen, rewriter, op); 1094 operands.push_back(load); 1095 } 1096 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1097 if (scalarRed) { 1098 codegen.redVal = merger.exp(codegen.redExp).val = 1099 forOp.getRegionIterArgs().front(); 1100 } 1101 // Assign induction variable to sparse or dense index. 1102 Value iv = forOp.getInductionVar(); 1103 if (isSparse) 1104 codegen.pidxs[tensor][idx] = iv; 1105 else 1106 codegen.loops[idx] = iv; 1107 rewriter.setInsertionPointToStart(forOp.getBody()); 1108 // Share vector iteration mask between all subsequent loads/stores. 1109 if (isVector) 1110 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1111 return forOp; 1112 } 1113 1114 /// Emit a while-loop for co-iteration over multiple indices. 
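/// As a rough sketch (using illustrative shorthand names only), co-iterating
/// two sparse vectors b and c for x(i) = b(i) + c(i) yields a loop of the
/// form
///   while (pb < pb_end && pc < pc_end) {
///     ib = b_indices[pb];
///     ic = c_indices[pc];
///     i = min(ib, ic);
///     if (ib == i && ic == i)
///       x[i] = b_values[pb] + c_values[pc];
///     else if (ib == i)
///       x[i] = b_values[pb];
///     else
///       x[i] = c_values[pc];
///     pb += (ib == i);
///     pc += (ic == i);
///   }
/// followed by simpler loops for the leftover entries of b or c alone. This
/// method only constructs the while-loop and its "before" region; the body,
/// if-statements, and induction updates are generated by genLocals, genIf,
/// and genWhileInduction below.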
1115 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1116 PatternRewriter &rewriter, linalg::GenericOp op, 1117 unsigned idx, bool needsUniv, 1118 llvm::BitVector &indices) { 1119 SmallVector<Type, 4> types; 1120 SmallVector<Value, 4> operands; 1121 // Construct the while-loop with a parameter for each index. 1122 Type indexType = rewriter.getIndexType(); 1123 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1124 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1125 unsigned tensor = merger.tensor(b); 1126 assert(idx == merger.index(b)); 1127 types.push_back(indexType); 1128 assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() && 1129 "type mismatch for sparse index"); 1130 operands.push_back(codegen.pidxs[tensor][idx]); 1131 } 1132 } 1133 if (needsUniv) { 1134 types.push_back(indexType); 1135 assert(codegen.loops[idx].getType().isa<IndexType>() && 1136 "type mismatch for universal index"); 1137 operands.push_back(codegen.loops[idx]); 1138 } 1139 Location loc = op.getLoc(); 1140 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 1141 Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 1142 Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 1143 1144 // Build the "before" region, which effectively consists 1145 // of a conjunction of "i < upper" tests on all induction. 1146 rewriter.setInsertionPointToStart(&whileOp.before().front()); 1147 Value cond; 1148 unsigned o = 0; 1149 for (unsigned b = 0, be = indices.size(); b < be; b++) { 1150 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 1151 unsigned tensor = merger.tensor(b); 1152 assert(idx == merger.index(b)); 1153 Value op1 = before->getArgument(o); 1154 Value op2 = codegen.highs[tensor][idx]; 1155 Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2); 1156 cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc; 1157 codegen.pidxs[tensor][idx] = after->getArgument(o++); 1158 } 1159 } 1160 if (needsUniv) 1161 codegen.loops[idx] = after->getArgument(o++); 1162 assert(o == operands.size()); 1163 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 1164 rewriter.setInsertionPointToStart(&whileOp.after().front()); 1165 return whileOp; 1166 } 1167 1168 /// Generates a for-loop or a while-loop, depending on whether it implements 1169 /// singleton iteration or co-iteration over the given conjunction. 1170 static Operation *genLoop(Merger &merger, CodeGen &codegen, 1171 PatternRewriter &rewriter, linalg::GenericOp op, 1172 std::vector<unsigned> &topSort, unsigned at, 1173 bool needsUniv, llvm::BitVector &indices) { 1174 unsigned idx = topSort[at]; 1175 if (indices.count() == 1) { 1176 bool isOuter = at == 0; 1177 bool isInner = at == topSort.size() - 1; 1178 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 1179 indices); 1180 } 1181 genReductionEnd(merger, codegen, rewriter, op); // cannot chain 1182 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 1183 } 1184 1185 /// Generates the local variables for this loop, consisting of the sparse 1186 /// indices, restored universal dense index, and dense positions. 1187 static void genLocals(Merger &merger, CodeGen &codegen, 1188 PatternRewriter &rewriter, linalg::GenericOp op, 1189 std::vector<unsigned> &topSort, unsigned at, 1190 bool needsUniv, llvm::BitVector &locals) { 1191 Location loc = op.getLoc(); 1192 unsigned idx = topSort[at]; 1193 1194 // Initialize sparse indices. 
1195 Value min; 1196 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1197 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1198 unsigned tensor = merger.tensor(b); 1199 assert(idx == merger.index(b)); 1200 Value ptr = codegen.indices[tensor][idx]; 1201 Value s = codegen.pidxs[tensor][idx]; 1202 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1203 codegen.idxs[tensor][idx] = load; 1204 if (!needsUniv) { 1205 if (min) { 1206 Value cmp = 1207 rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min); 1208 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1209 } else { 1210 min = load; 1211 } 1212 } 1213 } 1214 } 1215 1216 // Merge dense universal index over minimum. 1217 if (min) { 1218 assert(!needsUniv); 1219 codegen.loops[idx] = min; 1220 } 1221 1222 // Initialize dense positions. Note that we generate dense indices of the 1223 // output tensor unconditionally, since they may not appear in the lattice, 1224 // but may be needed for linearized codegen. 1225 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1226 if ((locals[b] || merger.isOutTensor(b, idx)) && 1227 merger.isDim(b, Dim::kDense)) { 1228 unsigned tensor = merger.tensor(b); 1229 assert(idx == merger.index(b)); 1230 unsigned pat = at; 1231 for (; pat != 0; pat--) 1232 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1233 break; 1234 Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 1235 : codegen.pidxs[tensor][topSort[pat - 1]]; 1236 codegen.pidxs[tensor][idx] = genAddress( 1237 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1238 } 1239 } 1240 } 1241 1242 /// Generates the induction structure for a while-loop. 1243 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1244 PatternRewriter &rewriter, linalg::GenericOp op, 1245 unsigned idx, bool needsUniv, 1246 llvm::BitVector &induction, ResultRange results) { 1247 Location loc = op.getLoc(); 1248 unsigned o = 0; 1249 SmallVector<Value, 4> operands; 1250 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 1251 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1252 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1253 unsigned tensor = merger.tensor(b); 1254 assert(idx == merger.index(b)); 1255 Value op1 = codegen.idxs[tensor][idx]; 1256 Value op2 = codegen.loops[idx]; 1257 Value op3 = codegen.pidxs[tensor][idx]; 1258 Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1259 Value add = rewriter.create<AddIOp>(loc, op3, one); 1260 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1261 codegen.pidxs[tensor][idx] = results[o++]; 1262 } 1263 } 1264 if (needsUniv) { 1265 operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one)); 1266 codegen.loops[idx] = results[o++]; 1267 } 1268 assert(o == operands.size()); 1269 rewriter.create<scf::YieldOp>(loc, operands); 1270 } 1271 1272 /// Generates a single if-statement within a while-loop. 
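/// The condition is a conjunction over the tensors in the (simplified)
/// lattice point, e.g. roughly (ib == i && ic == i) for two sparse tensors
/// b and c, where dense dimensions simply contribute a true clause.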
1273 static scf::IfOp genIf(Merger &merger, CodeGen &codegen, 1274 PatternRewriter &rewriter, linalg::GenericOp op, 1275 unsigned idx, llvm::BitVector &conditions) { 1276 Location loc = op.getLoc(); 1277 Value cond; 1278 for (unsigned b = 0, be = conditions.size(); b < be; b++) { 1279 if (conditions[b]) { 1280 unsigned tensor = merger.tensor(b); 1281 assert(idx == merger.index(b)); 1282 Value clause; 1283 if (merger.isDim(b, Dim::kSparse)) { 1284 Value op1 = codegen.idxs[tensor][idx]; 1285 Value op2 = codegen.loops[idx]; 1286 clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1287 } else { 1288 clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true 1289 } 1290 cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause; 1291 } 1292 } 1293 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true); 1294 rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); 1295 return ifOp; 1296 } 1297 1298 /// Recursively generates code while computing iteration lattices in order 1299 /// to manage the complexity of implementing co-iteration over unions 1300 /// and intersections of sparse iterations spaces. 1301 static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 1302 linalg::GenericOp op, std::vector<unsigned> &topSort, 1303 unsigned exp, unsigned at) { 1304 // At each leaf, assign remaining tensor (sub)expression to output tensor. 1305 if (at == topSort.size()) { 1306 OpOperand *lhs = op.getOutputOperand(0); 1307 Value rhs = genExp(merger, codegen, rewriter, op, exp); 1308 genTensorStore(merger, codegen, rewriter, op, lhs, rhs); 1309 return; 1310 } 1311 assert(codegen.curVecLength == 1); 1312 1313 // Construct iteration lattices for current loop index, with L0 at top. 1314 // Then emit initialization code for the loop sequence at this level. 1315 // We maintain the universal dense index if dense indices are still 1316 // in play for a non-singleton loop sequence. 1317 Location loc = op.getLoc(); 1318 unsigned idx = topSort[at]; 1319 unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx)); 1320 unsigned lsize = merger.set(lts).size(); 1321 assert(lsize != 0); 1322 unsigned l0 = merger.set(lts)[0]; 1323 unsigned ldx = at == 0 ? -1u : topSort[at - 1]; 1324 genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true); 1325 bool needsUniv = false; 1326 if (genInit(merger, codegen, rewriter, op, topSort, at, 1327 merger.lat(l0).bits)) { 1328 // Maintain the universal index only if it is actually 1329 // consumed by a subsequent lattice point. 1330 for (unsigned i = 1; i < lsize; i++) { 1331 unsigned li = merger.set(lts)[i]; 1332 if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) { 1333 needsUniv = true; 1334 break; 1335 } 1336 } 1337 } 1338 1339 // Emit a loop for every lattice point L0 >= Li. 1340 for (unsigned i = 0; i < lsize; i++) { 1341 unsigned li = merger.set(lts)[i]; 1342 1343 // Emit loop. 1344 codegen.curVecLength = 1; 1345 llvm::BitVector indices = merger.lat(li).simple; 1346 Operation *loop = 1347 genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices); 1348 genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, 1349 merger.lat(li).bits); 1350 1351 // Visit all lattices points with Li >= Lj to generate the 1352 // loop-body, possibly with if statements for coiteration. 
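// For example, for x(i) = b(i) + c(i) with b and c sparse, the lattice
// points are b /\ c, b, and c; inside the while-loop emitted for b /\ c,
// all three points satisfy Li >= Lj and thus yield an if-branch (both
// indices match, only b matches, only c matches), whereas the subsequent
// loops for b and c alone handle their leftover entries directly.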
1353 bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
1354 for (unsigned j = 0; j < lsize; j++) {
1355 unsigned lj = merger.set(lts)[j];
1356 unsigned ej = merger.lat(lj).exp;
1357 if (li == lj || merger.latGT(li, lj)) {
1358 // Recurse into body of each branch.
1359 if (isWhile) {
1360 scf::IfOp ifOp =
1361 genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
1362 genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1363 rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
1364 } else {
1365 genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
1366 }
1367 }
1368 }
1369
1370 // Wrap-up induction and restore insertion point.
1371 if (isWhile) {
1372 scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
1373 rewriter.setInsertionPointToEnd(&whileOp.after().front());
1374 genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
1375 merger.lat(li).bits, whileOp.results());
1376 } else {
1377 needsUniv = false;
1378 if (codegen.redVal) {
1379 rewriter.create<scf::YieldOp>(loc, codegen.redVal);
1380 codegen.redVal = loop->getResult(0);
1381 }
1382 }
1383 rewriter.setInsertionPointAfter(loop);
1384 }
1385
1386 // Wrap-up loop sequence.
1387 codegen.curVecLength = 1;
1388 genReductionEnd(merger, codegen, rewriter, op);
1389 genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
1390 codegen.loops[idx] = Value();
1391 }
1392
1393 /// Converts the result computed by the sparse kernel into the required form.
1394 static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
1395 linalg::GenericOp op) {
1396 RankedTensorType resType = op.getOutputTensorTypes()[0];
1397 Value result = codegen.buffers.back();
1398 if (getSparseTensorEncoding(resType))
1399 result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
1400 else
1401 result =
1402 rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
1403 rewriter.replaceOp(op, result);
1404 }
1405
1406 namespace {
1407
1408 /// Sparse rewriting rule for generic Linalg operation.
1409 struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1410 public:
1411 GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
1412 : OpRewritePattern<linalg::GenericOp>(context), options(o) {}
1413
1414 LogicalResult matchAndRewrite(linalg::GenericOp op,
1415 PatternRewriter &rewriter) const override {
1416 // Detects sparse annotations and translates the per-dimension sparsity
1417 // information for all tensors to loop indices in the kernel.
1418 assert(op.getNumOutputs() == 1);
1419 unsigned numTensors = op.getNumInputsAndOutputs();
1420 unsigned numLoops = op.iterator_types().getValue().size();
1421 Merger merger(numTensors, numLoops);
1422 if (!findSparseAnnotations(merger, op))
1423 return failure();
1424
1425 // Computes a topologically sorted iteration graph to ensure
1426 // tensors are visited in natural index order. Fails on cycles.
1427 // This assumes that higher-level passes have already put the
1428 // tensors in each tensor expression in a feasible order.
1429 std::vector<unsigned> topSort;
1430 if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
1431 !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
1432 return failure();
1433
1434 // Finds the terminating yield statement and builds the tensor
1435 // expression for the Linalg operation in SSA form.
1436 Operation *yield = op.region().front().getTerminator(); 1437 Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0)); 1438 if (!exp.hasValue()) 1439 return failure(); // build failure 1440 1441 // Recursively generates code. 1442 CodeGen codegen(options, numTensors, numLoops); 1443 if (!genBuffers(merger, codegen, rewriter, op)) 1444 return failure(); // could not bufferize 1445 genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); 1446 genResult(codegen, rewriter, op); 1447 return success(); 1448 } 1449 1450 private: 1451 /// Options to control sparse code generation. 1452 SparsificationOptions options; 1453 }; 1454 1455 } // namespace 1456 1457 /// Populates the given patterns list with rewriting rules required for 1458 /// the sparsification of linear algebra operations. 1459 void mlir::populateSparsificationPatterns( 1460 RewritePatternSet &patterns, const SparsificationOptions &options) { 1461 patterns.add<GenericOpSparsifier>(patterns.getContext(), options); 1462 } 1463
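// For illustration, a client would typically wire this rewriting rule into a
// pass roughly as follows (sketch only; pass boilerplate and options omitted):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
//
// after which every annotated linalg.generic in the function has been
// replaced by imperative for/while loops over its sparse storage schemes.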