1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering sparse tensor types to actual sparse code. 10 // 11 // The concept of letting a compiler generate sparse code automatically was 12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and 13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor 14 // Algebra Compiler (TACO). The implementation in this file closely follows 15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting 16 // rule is applied to each tensor expression in linalg (MLIR's tensor index 17 // notation) where the sparsity of tensors is indicated with annotation using 18 // a per-dimension specification of sparse/dense storage together with a 19 // specification of the order on the dimensions. Subsequently, a topologically 20 // sorted iteration graph, reflecting the required order on indices with respect 21 // to the dimensions of each tensor, is constructed to ensure that all tensors 22 // are visited in natural index order. Next, iteration lattices are constructed 23 // for the tensor expression for every index in topological order. Each 24 // iteration lattice point consists of a conjunction of tensor indices together 25 // with a tensor (sub)expression that needs to be evaluated for that 26 // conjunction. Within the lattice, iteration points are ordered according to 27 // the way indices are exhausted. As such these iteration lattices drive actual 28 // sparse code generation, which consists of a tedious but relatively 29 // straightforward one-to-one mapping from iteration lattices to combinations 30 // of for-loops, while-loops, and if-statements. 31 // 32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations. 33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php). 34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou, 35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler. 36 // Proceedings of the ACM on Programming Languages, October 2017. 37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation. 38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org). 39 // 40 // Implementation detail: We use llvm::SmallVector for vectors with 41 // variable lengths and std::vector for vectors with fixed lengths. 42 //===----------------------------------------------------------------------===// 43 44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 45 #include "mlir/Dialect/Linalg/Utils/Utils.h" 46 #include "mlir/Dialect/SCF/SCF.h" 47 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 48 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 49 #include "mlir/Dialect/StandardOps/IR/Ops.h" 50 #include "mlir/Dialect/Vector/VectorOps.h" 51 #include "mlir/IR/Matchers.h" 52 #include "mlir/IR/TensorEncoding.h" 53 #include "llvm/ADT/SmallBitVector.h" 54 55 using namespace mlir; 56 using namespace mlir::sparse_tensor; 57 58 namespace { 59 60 enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; 61 enum class Dim { kSparse, kDense, kSingle, kUndef }; 62 63 /// Tensor expression. Represents a MLIR expression in tensor index notation. 
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the indices of the
/// child tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of the child expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation, this
  /// field is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
  llvm::BitVector bits;
  /// Simplified conjunction of tensor loop indices as bitvector. This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during codegen to avoid repeated eval.
  llvm::BitVector simple;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The
  /// user supplies the number of tensors involved in the kernel, with the
  /// last tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  Merger(unsigned t, unsigned l)
      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
    unsigned e = tensorExps.size();
    tensorExps.push_back(TensorExp(k, e0, e1, v));
    return e;
  }
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e) {
    assert(t < numTensors && i < numLoops);
    unsigned p = latPoints.size();
    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
    return p;
  }

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet() {
    unsigned s = latSets.size();
    latSets.emplace_back(SmallVector<unsigned, 16>());
    return s;
  }

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of those
  /// indices) with a newly constructed tensor (sub)expression of given kind.
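  /// For example (illustrative), conjoining the points {i_a} and {i_b} for a
  /// multiplication yields a new point over {i_a, i_b} whose expression is
  /// the product of the two child (sub)expressions.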
143 /// Returns the index of the new lattice point. 144 unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 145 unsigned p = latPoints.size(); 146 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 147 nb |= latPoints[p1].bits; 148 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 149 latPoints.push_back(LatPoint(nb, e)); 150 return p; 151 } 152 153 /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of 154 /// cartesian product. Returns the index of the new set. 155 unsigned takeConj(Kind kind, unsigned s0, unsigned s1) { 156 unsigned s = addSet(); 157 for (unsigned p0 : latSets[s0]) 158 for (unsigned p1 : latSets[s1]) 159 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 160 return s; 161 } 162 163 /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). 164 /// Returns the index of the new set. 165 unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) { 166 unsigned s = takeConj(kind, s0, s1); 167 for (unsigned p : latSets[s0]) 168 latSets[s].push_back(p); 169 for (unsigned p : latSets[s1]) 170 latSets[s].push_back(p); 171 return s; 172 } 173 174 /// Optimizes the iteration lattice points in the given set. This 175 /// method should be called right before code generation to avoid 176 /// generating redundant loops and conditions. 177 unsigned optimizeSet(unsigned s0) { 178 unsigned s = addSet(); 179 assert(latSets[s0].size() != 0); 180 unsigned p0 = latSets[s0][0]; 181 for (unsigned p1 : latSets[s0]) { 182 bool add = true; 183 if (p0 != p1) { 184 // Is this a straightforward copy? 185 unsigned e = latPoints[p1].exp; 186 if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) 187 continue; 188 // Conjunction already covered? 189 for (unsigned p2 : latSets[s]) { 190 assert(!latGT(p1, p2)); // Lj => Li would be bad 191 if (onlyDenseDiff(p2, p1)) { 192 add = false; 193 break; 194 } 195 } 196 assert(!add || latGT(p0, p1)); 197 } 198 if (add) 199 latSets[s].push_back(p1); 200 } 201 for (unsigned p : latSets[s]) 202 latPoints[p].simple = simplifyCond(s, p); 203 return s; 204 } 205 206 /// Simplifies the conditions in a conjunction of a given lattice point 207 /// within the given set using just two basic rules: 208 /// (1) multiple dense conditions are reduced to single dense, and 209 /// (2) a *singleton* sparse/dense is reduced to sparse/random access. 210 llvm::BitVector simplifyCond(unsigned s, unsigned p0) { 211 // First determine if this lattice point is a *singleton*, i.e., 212 // the last point in a lattice, no other is less than this one. 213 bool isSingleton = true; 214 for (unsigned p1 : latSets[s]) { 215 if (p0 != p1 && latGT(p0, p1)) { 216 isSingleton = false; 217 break; 218 } 219 } 220 // Now apply the two basic rules. 221 llvm::BitVector simple = latPoints[p0].bits; 222 bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); 223 for (unsigned b = 0, be = simple.size(); b < be; b++) { 224 if (simple[b] && !isDim(b, Dim::kSparse)) { 225 if (reset) 226 simple.reset(b); 227 reset = true; 228 } 229 } 230 return simple; 231 } 232 233 /// Returns true if Li > Lj. 234 bool latGT(unsigned i, unsigned j) const { 235 const llvm::BitVector &bitsi = latPoints[i].bits; 236 const llvm::BitVector &bitsj = latPoints[j].bits; 237 assert(bitsi.size() == bitsj.size()); 238 if (bitsi.count() > bitsj.count()) { 239 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 240 if (bitsj[b] && !bitsi[b]) 241 return false; 242 return true; 243 } 244 return false; 245 } 246 247 /// Returns true if Li and Lj only differ in dense. 
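  /// For example (illustrative), {i_sparse, j_dense} and {i_sparse} differ
  /// only in the dense index j, so the former adds no new sparse iteration
  /// and its conjunction is considered covered.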
248 bool onlyDenseDiff(unsigned i, unsigned j) { 249 llvm::BitVector tmp = latPoints[j].bits; 250 tmp ^= latPoints[i].bits; 251 return !hasAnyDimOf(tmp, Dim::kSparse); 252 } 253 254 /// Bit translation. 255 unsigned tensor(unsigned b) const { return b % numTensors; } 256 unsigned index(unsigned b) const { return b / numTensors; } 257 258 /// Returns true if bit corresponds to queried dim. 259 bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } 260 261 /// Returns true if bit corresponds to index of output tensor. 262 bool isOutTensor(unsigned b, unsigned i) const { 263 return tensor(b) == outTensor && index(b) == i; 264 } 265 266 /// Returns true if tensor access at given index has queried dim. 267 bool isDim(unsigned t, unsigned i, Dim d) const { 268 assert(t < numTensors && i < numLoops); 269 return dims[t][i] == d; 270 } 271 272 /// Returns true if any set bit corresponds to queried dim. 273 bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 274 for (unsigned b = 0, be = bits.size(); b < be; b++) 275 if (bits[b] && isDim(b, d)) 276 return true; 277 return false; 278 } 279 280 /// Setter 281 void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } 282 283 /// Getters. 284 TensorExp &exp(unsigned e) { return tensorExps[e]; } 285 LatPoint &lat(unsigned l) { return latPoints[l]; } 286 SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; } 287 288 private: 289 const unsigned outTensor; 290 const unsigned numTensors; 291 const unsigned numLoops; 292 293 std::vector<std::vector<Dim>> dims; 294 llvm::SmallVector<TensorExp, 32> tensorExps; 295 llvm::SmallVector<LatPoint, 16> latPoints; 296 llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets; 297 }; 298 299 // Code generation. 300 struct CodeGen { 301 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) 302 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 303 pointers(numTensors, std::vector<Value>(numLoops)), 304 indices(numTensors, std::vector<Value>(numLoops)), 305 highs(numTensors, std::vector<Value>(numLoops)), 306 pidxs(numTensors, std::vector<Value>(numLoops)), 307 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 308 curVecLength(1), curVecMask() {} 309 /// Sparsification options. 310 SparsificationOptions options; 311 /// Universal dense indices and upper bounds (by index). The loops array 312 /// is updated with the value of the universal dense index in the current 313 /// loop. The sizes array is set once with the inferred dimension sizes. 314 std::vector<Value> loops; 315 std::vector<Value> sizes; 316 /// Buffers for storing dense and sparse numerical values (by tensor). 317 /// This array is set once during bufferization of all tensors. 318 std::vector<Value> buffers; 319 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 320 /// This array is set once during bufferization of all sparse tensors. 321 std::vector<std::vector<Value>> pointers; 322 std::vector<std::vector<Value>> indices; 323 /// Sparse iteration information (by tensor and index). These arrays 324 /// are updated to remain current within the current loop. 325 std::vector<std::vector<Value>> highs; 326 std::vector<std::vector<Value>> pidxs; 327 std::vector<std::vector<Value>> idxs; 328 /// Current reduction, updated during code generation. When indices of a 329 /// reduction are exhausted, all inner loops can "scalarize" the reduction. 
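  // For example (illustrative), a kernel such as x = SUM_i a(i) * b(i)
  // carries the running sum in redVal (an iter_arg of the innermost scf.for)
  // instead of loading and storing x on every iteration.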
330 // TODO: currently only done for (a chain of) innermost for-loops, where it 331 // is most effective; we could generalize to more outer and while-loops. 332 unsigned redExp; 333 Value redVal; 334 // Current vector length and mask. 335 unsigned curVecLength; 336 Value curVecMask; 337 }; 338 339 } // namespace 340 341 // Helper method to apply dimension ordering permutation. 342 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) { 343 if (enc) { 344 auto order = enc.getDimOrdering(); 345 if (order) { 346 assert(order.isPermutation()); 347 return order.getDimPosition(d); 348 } 349 } 350 return d; 351 } 352 353 // Helper method to translate dim level type to internal representation. 354 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) { 355 if (enc) { 356 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 357 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 358 return Dim::kSparse; 359 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 360 return Dim::kSingle; 361 } 362 return Dim::kDense; 363 } 364 365 /// Helper method to inspect sparse encodings in the tensor types. 366 /// Fills the per-dimension sparsity information for all tensors. 367 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 368 bool annotated = false; 369 OpOperand *lhs = op.getOutputOperand(0); 370 for (OpOperand *t : op.getInputAndOutputOperands()) { 371 auto map = op.getTiedIndexingMap(t); 372 if (!map.isProjectedPermutation()) 373 return false; 374 auto enc = getSparseTensorEncoding(t->get().getType()); 375 if (enc) 376 annotated = true; 377 assert(map.getNumResults() == op.getRank(t)); 378 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 379 unsigned idx = map.getDimPosition(perm(enc, d)); 380 Dim dim = toDim(enc, d); 381 merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); 382 // Accept only all-dense annotated "sparse" output. 383 // TODO: support truly sparse outputs too 384 if (t == lhs && dim != Dim::kDense) 385 return false; 386 } 387 } 388 return annotated; 389 } 390 391 /// A DFS helper to compute a topological sort. Note that recursion is 392 /// bounded by the number of implicit loops, which is always small. 393 /// Returns false when a cycle is detected. 394 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 395 std::vector<unsigned> &topSort, 396 std::vector<std::vector<bool>> &adjM) { 397 if (visit[i] != 0) 398 return visit[i] != 1; // 1 denotes cycle! 399 visit[i] = 1; 400 for (unsigned j = 0, e = visit.size(); j < e; j++) 401 if (adjM[i][j]) 402 if (!topSortDFS(j, visit, topSort, adjM)) 403 return false; 404 visit[i] = 2; 405 topSort.push_back(i); 406 return true; 407 } 408 409 /// Computes a topologically sorted iteration graph for the linalg operation. 410 /// Ensures all tensors are visited in natural index order. This is essential 411 /// for sparse storage formats since these only support access along fixed 412 /// dimensions. Even for dense storage formats, however, the natural index 413 /// order yields innermost unit-stride access with better spatial locality. 414 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 415 std::vector<unsigned> &topSort, 416 bool sparseOnly) { 417 // Set up an n x n from/to adjacency matrix of the iteration graph 418 // for the implicit loop indices i_0 .. i_n-1. 
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any parameter of the generic op is considered a tensor,
    // indexed by the implicit loop bounds.
    if (arg.getOwner()->getParentOp() == op)
      return merger.addExp(Kind::kTensor, argN);
    // Any parameter of a higher op is invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension.
An invariant expression 501 // is set to a synthetic tensor with undefined indices only. 502 unsigned s = merger.addSet(); 503 unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 504 : op.getNumInputsAndOutputs(); 505 merger.set(s).push_back(merger.addLat(t, idx, exp)); 506 return s; 507 } 508 unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); 509 unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx); 510 switch (kind) { 511 case Kind::kTensor: 512 case Kind::kInvariant: 513 llvm_unreachable("handled above"); 514 case Kind::kMulF: 515 case Kind::kMulI: 516 return merger.takeConj(kind, s0, s1); 517 case Kind::kAddF: 518 case Kind::kAddI: 519 return merger.takeDisj(kind, s0, s1); 520 } 521 llvm_unreachable("unexpected expression kind"); 522 } 523 524 /// Maps sparse integer option to actual integral storage type. 525 static Type genIntType(PatternRewriter &rewriter, unsigned width) { 526 if (width == 0) 527 return rewriter.getIndexType(); 528 return rewriter.getIntegerType(width); 529 } 530 531 /// Detects in-place annotation on tensor argument. 532 static bool getInPlace(Value val) { 533 if (auto arg = val.dyn_cast<BlockArgument>()) 534 if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp())) 535 if (auto attr = funcOp.getArgAttrOfType<BoolAttr>( 536 arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) 537 return attr.getValue(); 538 return false; 539 } 540 541 /// Generates buffer for the output tensor. 542 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter, 543 linalg::GenericOp op, MemRefType denseTp, 544 ArrayRef<Value> args) { 545 Location loc = op.getLoc(); 546 Value tensor = op.getOutputOperand(0)->get(); 547 // The output tensor simply could materialize from the buffer that will 548 // be generated for the tensor present in the outs() clause. This has 549 // the major advantage that the sparse kernel only updates the nonzero 550 // positions for the output tensor. 551 if (getInPlace(tensor)) 552 return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 553 // By default, a new buffer is allocated which is initialized to the 554 // tensor defined in the outs() clause. This is always correct but 555 // introduces a dense initialization component that may negatively 556 // impact the running complexity of the sparse kernel. 557 Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 558 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 559 rewriter.create<linalg::CopyOp>(loc, init, alloc); 560 return alloc; 561 } 562 563 /// Local bufferization of all dense and sparse data structures. 564 /// This code enables testing the first prototype sparse compiler. 565 // TODO: replace this with a proliferated bufferization strategy 566 static bool genBuffers(Merger &merger, CodeGen &codegen, 567 PatternRewriter &rewriter, linalg::GenericOp op) { 568 Location loc = op.getLoc(); 569 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 570 // For every tensor, find lower and upper bound on dimensions, set the 571 // same bounds on loop indices, and obtain dense or sparse buffer(s). 572 SmallVector<Value, 4> args; 573 for (OpOperand *t : op.getInputAndOutputOperands()) { 574 unsigned tensor = t->getOperandNumber(); 575 auto shape = op.getShape(t); 576 auto map = op.getTiedIndexingMap(t); 577 auto enc = getSparseTensorEncoding(t->get().getType()); 578 // Scan all dimensions of current tensor. 
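    // For example (illustrative), a matrix annotated (dense, compressed) in
    // the CSR sense contributes pointers/indices memrefs only for its
    // compressed dimension d == 1 below, while the dense dimension d == 0
    // merely defines an upper bound.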
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<memref::DimOp>(loc, t->get(), d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
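  // For example (illustrative), lo = 0, hi = 128, step = 16 takes the special
  // case below and yields a broadcast of the constant true, whereas hi = 100
  // falls through to a vector.create_mask of size (hi - iv), which covers
  // only 4 lanes in the last iteration.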
652 IntegerAttr loInt, hiInt, stepInt; 653 if (matchPattern(lo, m_Constant(&loInt)) && 654 matchPattern(hi, m_Constant(&hiInt)) && 655 matchPattern(step, m_Constant(&stepInt))) { 656 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 657 return rewriter.create<vector::BroadcastOp>( 658 loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1)); 659 } 660 // Otherwise, generate a vector mask that avoids overrunning the upperbound 661 // during vector execution. Here we rely on subsequent loop optimizations to 662 // avoid executing the mask in all iterations, for example, by splitting the 663 // loop into an unconditional vector loop and a scalar cleanup loop. 664 Value end = rewriter.create<SubIOp>(loc, hi, iv); 665 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 666 } 667 668 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 669 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 670 Value ptr, ArrayRef<Value> args) { 671 Location loc = ptr.getLoc(); 672 VectorType vtp = vectorType(codegen, ptr); 673 Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 674 if (args.back().getType().isa<VectorType>()) { 675 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 676 Value indexVec = args.back(); 677 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 678 return rewriter.create<vector::GatherOp>( 679 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 680 } 681 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 682 codegen.curVecMask, pass); 683 } 684 685 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 686 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 687 Value rhs, Value ptr, ArrayRef<Value> args) { 688 Location loc = ptr.getLoc(); 689 if (args.back().getType().isa<VectorType>()) { 690 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 691 Value indexVec = args.back(); 692 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 693 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 694 codegen.curVecMask, rhs); 695 return; 696 } 697 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 698 rhs); 699 } 700 701 /// Generates a vectorized invariant. Here we rely on subsequent loop 702 /// optimizations to hoist the invariant broadcast out of the vector loop. 703 static Value genVectorInvariantValue(CodeGen &codegen, 704 PatternRewriter &rewriter, Value val) { 705 VectorType vtp = vectorType(codegen, val.getType()); 706 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 707 } 708 709 /// Generates a load on a dense or sparse tensor. 710 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 711 PatternRewriter &rewriter, linalg::GenericOp op, 712 unsigned exp) { 713 // Test if the load was hoisted to a higher loop nest. 714 Value val = merger.exp(exp).val; 715 if (val) { 716 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 717 return genVectorInvariantValue(codegen, rewriter, val); 718 return val; 719 } 720 // Actual load. 721 SmallVector<Value, 4> args; 722 OpOperand *t = merger.exp(exp).e0 < op.getNumInputs() 723 ? 
op.getInputOperand(merger.exp(exp).e0)
                     : op.getOutputOperand(0);
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential for
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
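    // For example (illustrative), an i16 index array loaded with vector
    // length 8 is masked-loaded as vector<8xi16> and zero extended to
    // vector<8xi32>, which the subsequent gather/scatter consumes as indices.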
802 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 803 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 804 if (!etp.isa<IndexType>()) { 805 if (etp.getIntOrFloatBitWidth() < 32) 806 vload = rewriter.create<ZeroExtendIOp>( 807 loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); 808 else if (etp.getIntOrFloatBitWidth() < 64 && 809 !codegen.options.enableSIMDIndex32) 810 vload = rewriter.create<ZeroExtendIOp>( 811 loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); 812 } 813 return vload; 814 } 815 // For the scalar case, we simply zero extend narrower indices into 64-bit 816 // values before casting to index without a performance penalty. Here too, 817 // however, indices that already are 64-bit, in theory, cannot express the 818 // full range as explained above. 819 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 820 if (!load.getType().isa<IndexType>()) { 821 if (load.getType().getIntOrFloatBitWidth() < 64) 822 load = rewriter.create<ZeroExtendIOp>(loc, load, 823 rewriter.getIntegerType(64)); 824 load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType()); 825 } 826 return load; 827 } 828 829 /// Generates an invariant value. 830 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 831 PatternRewriter &rewriter, unsigned exp) { 832 Value val = merger.exp(exp).val; 833 if (codegen.curVecLength > 1) 834 return genVectorInvariantValue(codegen, rewriter, val); 835 return val; 836 } 837 838 /// Generates an address computation "sz * p + i". 839 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 840 Location loc, Value size, Value p, Value i) { 841 Value mul = rewriter.create<MulIOp>(loc, size, p); 842 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 843 Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType()); 844 mul = genVectorInvariantValue(codegen, rewriter, inv); 845 } 846 return rewriter.create<AddIOp>(loc, mul, i); 847 } 848 849 /// Generates start of a reduction. 850 static Value genReductionStart(Merger &merger, CodeGen &codegen, 851 PatternRewriter &rewriter, 852 linalg::GenericOp op) { 853 if (codegen.redVal) 854 return codegen.redVal; // chained with previous for-loop 855 if (codegen.curVecLength > 1) { 856 // TODO: assumes + reductions for now 857 VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]); 858 return rewriter.create<ConstantOp>(op.getLoc(), vtp, 859 rewriter.getZeroAttr(vtp)); 860 } 861 return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 862 } 863 864 /// Generates end of a reduction. 865 static void genReductionEnd(Merger &merger, CodeGen &codegen, 866 PatternRewriter &rewriter, linalg::GenericOp op) { 867 Value red = codegen.redVal; 868 if (!red) 869 return; 870 assert(codegen.curVecLength == 1); 871 codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain 872 OpOperand *lhs = op.getOutputOperand(0); 873 if (auto vtp = red.getType().dyn_cast<VectorType>()) { 874 // TODO: assumes + reductions for now 875 StringAttr kind = rewriter.getStringAttr("add"); 876 Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 877 // Integer reductions don't accept an accumulator. 
878 if (vtp.getElementType().isa<IntegerType>()) { 879 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 880 kind, red, ValueRange{}); 881 red = rewriter.create<AddIOp>(op.getLoc(), red, ld); 882 } else { 883 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 884 kind, red, ld); 885 } 886 } 887 genTensorStore(merger, codegen, rewriter, op, lhs, red); 888 } 889 890 /// Recursively generates tensor expression. 891 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 892 linalg::GenericOp op, unsigned exp) { 893 if (merger.exp(exp).kind == Kind::kTensor) 894 return genTensorLoad(merger, codegen, rewriter, op, exp); 895 else if (merger.exp(exp).kind == Kind::kInvariant) 896 return genInvariantValue(merger, codegen, rewriter, exp); 897 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); 898 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); 899 switch (merger.exp(exp).kind) { 900 case Kind::kTensor: 901 case Kind::kInvariant: 902 llvm_unreachable("handled above"); 903 case Kind::kMulF: 904 return rewriter.create<MulFOp>(op.getLoc(), v0, v1); 905 case Kind::kMulI: 906 return rewriter.create<MulIOp>(op.getLoc(), v0, v1); 907 case Kind::kAddF: 908 return rewriter.create<AddFOp>(op.getLoc(), v0, v1); 909 case Kind::kAddI: 910 return rewriter.create<AddIOp>(op.getLoc(), v0, v1); 911 } 912 llvm_unreachable("unexpected expression kind"); 913 } 914 915 /// Hoists loop invariant tensor loads for which indices have been exhausted. 916 static void genInvariants(Merger &merger, CodeGen &codegen, 917 PatternRewriter &rewriter, linalg::GenericOp op, 918 unsigned exp, unsigned ldx, bool hoist) { 919 if (merger.exp(exp).kind == Kind::kTensor) { 920 // Inspect tensor indices. 921 bool atLevel = ldx == -1u; 922 OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs() 923 ? op.getInputOperand(merger.exp(exp).e0) 924 : op.getOutputOperand(0); 925 auto map = op.getTiedIndexingMap(tensor); 926 auto enc = getSparseTensorEncoding(tensor->get().getType()); 927 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 928 unsigned idx = map.getDimPosition(perm(enc, d)); 929 if (!codegen.loops[idx]) 930 return; // still in play 931 else if (idx == ldx) 932 atLevel = true; 933 } 934 // All exhausted at this level (atLevel denotes exactly at this level). 935 OpOperand *lhs = op.getOutputOperand(0); 936 if (lhs == tensor) { 937 codegen.redExp = hoist ? exp : -1u; 938 } else if (atLevel) { 939 merger.exp(exp).val = 940 hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 941 } 942 } else if (merger.exp(exp).kind != Kind::kInvariant) { 943 // Traverse into the binary operations. Note that we only hoist 944 // tensor loads, since subsequent MLIR/LLVM passes know how to 945 // deal with all other kinds of derived loop invariants. 946 unsigned e0 = merger.exp(exp).e0; 947 unsigned e1 = merger.exp(exp).e1; 948 genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); 949 genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); 950 } 951 } 952 953 /// Generates initialization code for the subsequent loop sequence at 954 /// current index level. Returns true if the loop sequence needs to 955 /// maintain the universal index. 
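/// For example (illustrative), a sparse dimension is set up to iterate from
/// pointers[p] to pointers[p+1], where p is the position inherited from the
/// enclosing loops (or zero at the outermost level), whereas a dense
/// dimension simply keeps the universal index alive.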
956 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 957 linalg::GenericOp op, std::vector<unsigned> &topSort, 958 unsigned at, llvm::BitVector &inits) { 959 bool needsUniv = false; 960 Location loc = op.getLoc(); 961 unsigned idx = topSort[at]; 962 963 // Initialize sparse positions. 964 for (unsigned b = 0, be = inits.size(); b < be; b++) { 965 if (inits[b]) { 966 unsigned tensor = merger.tensor(b); 967 assert(idx == merger.index(b)); 968 if (merger.isDim(b, Dim::kSparse)) { 969 // Initialize sparse index. 970 unsigned pat = at; 971 for (; pat != 0; pat--) { 972 if (codegen.pidxs[tensor][topSort[pat - 1]]) 973 break; 974 } 975 Value ptr = codegen.pointers[tensor][idx]; 976 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 977 Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 978 : codegen.pidxs[tensor][topSort[pat - 1]]; 979 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 980 Value p1 = rewriter.create<AddIOp>(loc, p0, one); 981 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 982 } else { 983 // Dense index still in play. 984 needsUniv = true; 985 } 986 } 987 } 988 989 // Initialize the universal dense index. 990 codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0); 991 return needsUniv; 992 } 993 994 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 995 /// operation is a candidate. Whether it is actually converted to SIMD code 996 /// depends on the requested strategy. 997 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 998 switch (codegen.options.vectorizationStrategy) { 999 case SparseVectorizationStrategy::kNone: 1000 return false; 1001 case SparseVectorizationStrategy::kDenseInnerLoop: 1002 return isInner && !isSparse; 1003 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 1004 return isInner; 1005 } 1006 llvm_unreachable("unexpected vectorization strategy"); 1007 } 1008 1009 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 1010 /// that is marked "parallel" is a candidate. Whether it is actually converted 1011 /// to a parallel operation depends on the requested strategy. 1012 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 1013 bool isSparse, bool isVector) { 1014 switch (codegen.options.parallelizationStrategy) { 1015 case SparseParallelizationStrategy::kNone: 1016 return false; 1017 case SparseParallelizationStrategy::kDenseOuterLoop: 1018 return isOuter && !isSparse && !isReduction && !isVector; 1019 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 1020 return isOuter && !isReduction && !isVector; 1021 case SparseParallelizationStrategy::kDenseAnyLoop: 1022 return !isSparse && !isReduction && !isVector; 1023 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 1024 return !isReduction && !isVector; 1025 } 1026 llvm_unreachable("unexpected parallelization strategy"); 1027 } 1028 1029 /// Checks unit strides for dense tensors. The iteration graph may have ignored 1030 /// dense access patterns in order to avoid cycles (sparse access patterns are 1031 /// always placed innermost), but that means dense access has become strided. 1032 /// For now, we reject vectorization of such cases. 
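/// For example (illustrative), when vectorizing loop i, a dense access
/// B(i, j) would make each step in i stride over an entire row of B, so the
/// candidate is rejected.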
1033 /// TODO: implement strided load/stores on dense arrays 1034 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 1035 unsigned idx) { 1036 for (OpOperand *t : op.getInputAndOutputOperands()) { 1037 if (!getSparseTensorEncoding(t->get().getType())) { 1038 auto map = op.getTiedIndexingMap(t); 1039 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 1040 if (map.getDimPosition(d) == idx && d != rank - 1) 1041 return false; 1042 } 1043 } 1044 } 1045 return true; 1046 } 1047 1048 /// Generates a for-loop on a single index. 1049 static Operation *genFor(Merger &merger, CodeGen &codegen, 1050 PatternRewriter &rewriter, linalg::GenericOp op, 1051 bool isOuter, bool isInner, unsigned idx, 1052 llvm::BitVector &indices) { 1053 unsigned fb = indices.find_first(); 1054 unsigned tensor = merger.tensor(fb); 1055 assert(idx == merger.index(fb)); 1056 auto iteratorTypes = op.iterator_types().getValue(); 1057 bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); 1058 bool isSparse = merger.isDim(fb, Dim::kSparse); 1059 bool isVector = isVectorFor(codegen, isInner, isSparse) && 1060 denseUnitStrides(merger, op, idx); 1061 bool isParallel = 1062 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 1063 1064 // Prepare vector length. 1065 if (isVector) 1066 codegen.curVecLength = codegen.options.vectorLength; 1067 1068 // Loop bounds and increment. 1069 Location loc = op.getLoc(); 1070 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 1071 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 1072 Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength); 1073 1074 // Emit a parallel loop. 1075 if (isParallel) { 1076 assert(!isVector); 1077 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 1078 if (isSparse) 1079 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 1080 else 1081 codegen.loops[idx] = parOp.getInductionVars()[0]; 1082 rewriter.setInsertionPointToStart(parOp.getBody()); 1083 return parOp; 1084 } 1085 1086 // Emit a sequential loop, potentially with a scalarized reduction. 1087 bool scalarRed = isInner && codegen.redExp != -1u; 1088 SmallVector<Value, 4> operands; 1089 if (scalarRed) { 1090 Value load = genReductionStart(merger, codegen, rewriter, op); 1091 operands.push_back(load); 1092 } 1093 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 1094 if (scalarRed) { 1095 codegen.redVal = merger.exp(codegen.redExp).val = 1096 forOp.getRegionIterArgs().front(); 1097 } 1098 // Assign induction variable to sparse or dense index. 1099 Value iv = forOp.getInductionVar(); 1100 if (isSparse) 1101 codegen.pidxs[tensor][idx] = iv; 1102 else 1103 codegen.loops[idx] = iv; 1104 rewriter.setInsertionPointToStart(forOp.getBody()); 1105 // Share vector iteration mask between all subsequent loads/stores. 1106 if (isVector) 1107 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 1108 return forOp; 1109 } 1110 1111 /// Emit a while-loop for co-iteration over multiple indices. 1112 static Operation *genWhile(Merger &merger, CodeGen &codegen, 1113 PatternRewriter &rewriter, linalg::GenericOp op, 1114 unsigned idx, bool needsUniv, 1115 llvm::BitVector &indices) { 1116 SmallVector<Type, 4> types; 1117 SmallVector<Value, 4> operands; 1118 // Construct the while-loop with a parameter for each index. 
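  // For example (illustrative), co-iterating two sparse tensors over index i
  // yields an scf.while that carries one position index per sparse tensor,
  // plus one extra operand for the universal dense index when still needed.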
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
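  // For example (illustrative), each sparse tensor t loads its current index
  // as idxs[t] = indices[t][pidxs[t]]; without a universal index, the loop
  // index then becomes the minimum over these loaded indices.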
1192 Value min; 1193 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1194 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1195 unsigned tensor = merger.tensor(b); 1196 assert(idx == merger.index(b)); 1197 Value ptr = codegen.indices[tensor][idx]; 1198 Value s = codegen.pidxs[tensor][idx]; 1199 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1200 codegen.idxs[tensor][idx] = load; 1201 if (!needsUniv) { 1202 if (min) { 1203 Value cmp = 1204 rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min); 1205 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1206 } else { 1207 min = load; 1208 } 1209 } 1210 } 1211 } 1212 1213 // Merge dense universal index over minimum. 1214 if (min) { 1215 assert(!needsUniv); 1216 codegen.loops[idx] = min; 1217 } 1218 1219 // Initialize dense positions. Note that we generate dense indices of the 1220 // output tensor unconditionally, since they may not appear in the lattice, 1221 // but may be needed for linearized codegen. 1222 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1223 if ((locals[b] || merger.isOutTensor(b, idx)) && 1224 merger.isDim(b, Dim::kDense)) { 1225 unsigned tensor = merger.tensor(b); 1226 assert(idx == merger.index(b)); 1227 unsigned pat = at; 1228 for (; pat != 0; pat--) 1229 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1230 break; 1231 Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 1232 : codegen.pidxs[tensor][topSort[pat - 1]]; 1233 codegen.pidxs[tensor][idx] = genAddress( 1234 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1235 } 1236 } 1237 } 1238 1239 /// Generates the induction structure for a while-loop. 1240 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1241 PatternRewriter &rewriter, linalg::GenericOp op, 1242 unsigned idx, bool needsUniv, 1243 llvm::BitVector &induction, ResultRange results) { 1244 Location loc = op.getLoc(); 1245 unsigned o = 0; 1246 SmallVector<Value, 4> operands; 1247 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 1248 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1249 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1250 unsigned tensor = merger.tensor(b); 1251 assert(idx == merger.index(b)); 1252 Value op1 = codegen.idxs[tensor][idx]; 1253 Value op2 = codegen.loops[idx]; 1254 Value op3 = codegen.pidxs[tensor][idx]; 1255 Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1256 Value add = rewriter.create<AddIOp>(loc, op3, one); 1257 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1258 codegen.pidxs[tensor][idx] = results[o++]; 1259 } 1260 } 1261 if (needsUniv) { 1262 operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one)); 1263 codegen.loops[idx] = results[o++]; 1264 } 1265 assert(o == operands.size()); 1266 rewriter.create<scf::YieldOp>(loc, operands); 1267 } 1268 1269 /// Generates a single if-statement within a while-loop. 
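/// For example (illustrative), a lattice point involving sparse tensors a and
/// b emits the condition (idx_a == i && idx_b == i), with dense tensors
/// simply contributing a constant true clause.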
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
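    // For example (illustrative), for x(i) = a(i) + b(i) with both inputs
    // sparse, the while-loop body tests {a,b}, {a}, and {b} in turn,
    // computing a + b, a, or b for the output, respectively.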
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
                      linalg::GenericOp op) {
  RankedTensorType resType = op.getOutputTensorTypes()[0];
  Value result = codegen.buffers.back();
  if (getSparseTensorEncoding(resType))
    result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
  else
    result =
        rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    assert(llvm::none_of(op.getInputAndOutputOperands(),
                         [&](OpOperand *t) { return op.isScalar(t); }));
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
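    // For example (illustrative), a body %0 = mulf %a, %b followed by
    // linalg.yield %0 becomes the expression kMulF(kTensor(0), kTensor(1)),
    // with block arguments mapped to their tensor argument numbers.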
1435 Operation *yield = op.region().front().getTerminator(); 1436 Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0)); 1437 if (!exp.hasValue()) 1438 return failure(); // build failure 1439 1440 // Recursively generates code. 1441 CodeGen codegen(options, numTensors, numLoops); 1442 if (!genBuffers(merger, codegen, rewriter, op)) 1443 return failure(); // could not bufferize 1444 genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); 1445 genResult(codegen, rewriter, op); 1446 return success(); 1447 } 1448 1449 private: 1450 /// Options to control sparse code generation. 1451 SparsificationOptions options; 1452 }; 1453 1454 } // namespace 1455 1456 /// Populates the given patterns list with rewriting rules required for 1457 /// the sparsification of linear algebra operations. 1458 void mlir::populateSparsificationPatterns( 1459 RewritePatternSet &patterns, const SparsificationOptions &options) { 1460 patterns.add<GenericOpSparsifier>(patterns.getContext(), options); 1461 } 1462