//===- Sparsification.cpp - Implementation of linalg sparsification -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of the annotated linalg dialect to sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with annotations using
// a per-dimension specification of sparse/dense storage together with a
// specification of the order on the dimensions. Subsequently, a topologically
// sorted iteration graph, reflecting the required order on indices with respect
// to the dimensions of each tensor, is constructed to ensure that all tensors
// are visited in natural index order. Next, iteration lattices are constructed
// for the tensor expression for every index in topological order. Each
// iteration lattice point consists of a conjunction of tensor indices together
// with a tensor (sub)expression that needs to be evaluated for that
// conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such these iteration lattices drive actual
// sparse code generation, which consists of a tedious but relatively
// straightforward one-to-one mapping from iteration lattices to combinations
// of for-loops, while-loops, and if-statements.
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

namespace {

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
enum class Dim { kSparse, kDense, kUndef };

/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index.
/// For invariants, the IR value is stored directly. For binary operations,
/// e0 and e1 denote the indices of the children tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of children expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation, this
  /// field is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector. This represents
  /// all indices involved in the tensor expression.
  llvm::BitVector bits;
  /// Simplified conjunction of tensor loop indices as bitvector. This
  /// represents a simplified condition under which this tensor expression
  /// must execute. Pre-computed during code generation to avoid repeated
  /// evaluation.
  llvm::BitVector simple;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  /// Constructs a merger for the given number of tensors and loops. The
  /// user supplies the number of tensors involved in the kernel, with the
  /// last tensor in this set denoting the output tensor. The merger adds an
  /// additional synthetic tensor at the end of this set to represent all
  /// invariant expressions in the kernel.
  Merger(unsigned t, unsigned l)
      : outTensor(t - 1), numTensors(t + 1), numLoops(l),
        dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
    unsigned e = tensorExps.size();
    tensorExps.push_back(TensorExp(k, e0, e1, v));
    return e;
  }
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e) {
    assert(t < numTensors && i < numLoops);
    unsigned p = latPoints.size();
    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
    return p;
  }

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet() {
    unsigned s = latSets.size();
    latSets.emplace_back(SmallVector<unsigned, 16>());
    return s;
  }

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of those
  /// indices) with a newly constructed tensor (sub)expression of given kind.
  /// Returns the index of the new lattice point.
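  /// For example (illustrative), combining a point over {i(A)} for
  /// subexpression A with a point over {i(B)} for subexpression B under
  /// kMulF yields a point over {i(A), i(B)} for the new subexpression A*B.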
  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
    unsigned p = latPoints.size();
    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
    nb |= latPoints[p1].bits;
    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
    latPoints.push_back(LatPoint(nb, e));
    return p;
  }

  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = addSet();
    for (unsigned p0 : latSets[s0])
      for (unsigned p1 : latSets[s1])
        latSets[s].push_back(conjLatPoint(kind, p0, p1));
    return s;
  }

  /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).
  /// Returns the index of the new set.
  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = takeConj(kind, s0, s1);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
    return s;
  }

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  unsigned optimizeSet(unsigned s0) {
    unsigned s = addSet();
    assert(latSets[s0].size() != 0);
    unsigned p0 = latSets[s0][0];
    for (unsigned p1 : latSets[s0]) {
      bool add = true;
      if (p0 != p1) {
        // Is this a straightforward copy?
        unsigned e = latPoints[p1].exp;
        if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
          continue;
        // Conjunction already covered?
        for (unsigned p2 : latSets[s]) {
          assert(!latGT(p1, p2)); // Lj => Li would be bad
          if (onlyDenseDiff(p2, p1)) {
            add = false;
            break;
          }
        }
        assert(!add || latGT(p0, p1));
      }
      if (add)
        latSets[s].push_back(p1);
    }
    for (unsigned p : latSets[s])
      latPoints[p].simple = simplifyCond(s, p);
    return s;
  }

  /// Simplifies the conditions in a conjunction of a given lattice point
  /// within the given set using just two basic rules:
  /// (1) multiple dense conditions are reduced to a single dense condition, and
  /// (2) a *singleton* sparse/dense is reduced to sparse/random access.
  llvm::BitVector simplifyCond(unsigned s, unsigned p0) {
    // First determine if this lattice point is a *singleton*, i.e.,
    // the last point in a lattice, no other is less than this one.
    bool isSingleton = true;
    for (unsigned p1 : latSets[s]) {
      if (p0 != p1 && latGT(p0, p1)) {
        isSingleton = false;
        break;
      }
    }
    // Now apply the two basic rules.
    llvm::BitVector simple = latPoints[p0].bits;
    bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
    for (unsigned b = 0, be = simple.size(); b < be; b++) {
      if (simple[b] && !isDim(b, Dim::kSparse)) {
        if (reset)
          simple.reset(b);
        reset = true;
      }
    }
    return simple;
  }

  /// Returns true if Li > Lj.
  bool latGT(unsigned i, unsigned j) const {
    const llvm::BitVector &bitsi = latPoints[i].bits;
    const llvm::BitVector &bitsj = latPoints[j].bits;
    assert(bitsi.size() == bitsj.size());
    if (bitsi.count() > bitsj.count()) {
      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
        if (bitsj[b] && !bitsi[b])
          return false;
      return true;
    }
    return false;
  }

  /// Returns true if Li and Lj only differ in dense.
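  /// For example (illustrative), points over {i(A dense), i(B sparse)} and
  /// {i(B sparse)} only differ in a dense bit, so one of them can be skipped
  /// during lattice optimization (see optimizeSet).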
  bool onlyDenseDiff(unsigned i, unsigned j) {
    llvm::BitVector tmp = latPoints[j].bits;
    tmp ^= latPoints[i].bits;
    return !hasAnyDimOf(tmp, Dim::kSparse);
  }

  /// Bit translation: loop index i and tensor t are encoded as
  /// bit b = i * numTensors + t.
  unsigned tensor(unsigned b) const { return b % numTensors; }
  unsigned index(unsigned b) const { return b / numTensors; }

  /// Returns true if bit corresponds to queried dim.
  bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); }

  /// Returns true if tensor access at given index has queried dim.
  bool isDim(unsigned t, unsigned i, Dim d) const {
    assert(t < numTensors && i < numLoops);
    return dims[t][i] == d;
  }

  /// Returns true if any set bit corresponds to queried dim.
  bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
    for (unsigned b = 0, be = bits.size(); b < be; b++)
      if (bits[b] && isDim(b, d))
        return true;
    return false;
  }

  /// Returns true if tensor has any sparse dimension.
  bool isSparseTensor(unsigned t) const {
    return llvm::any_of(dims[t], [](Dim d) { return d == Dim::kSparse; });
  }

  /// Setter.
  void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }

  /// Getters.
  TensorExp &exp(unsigned e) { return tensorExps[e]; }
  LatPoint &lat(unsigned l) { return latPoints[l]; }
  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }

private:
  const unsigned outTensor;
  const unsigned numTensors;
  const unsigned numLoops;

  std::vector<std::vector<Dim>> dims;
  llvm::SmallVector<TensorExp, 32> tensorExps;
  llvm::SmallVector<LatPoint, 16> latPoints;
  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};

// Code generation.
struct CodeGen {
  CodeGen(mlir::SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  mlir::SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// These arrays are set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
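  /// For example (illustrative), for a kernel like x(i) = SUM_j A(i,j) * b(j),
  /// the innermost j-loop can accumulate into a scalar (or vector) register
  /// and write the result back to x only once per i.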
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

/// Helper method to inspect sparse annotations in the linalg operation.
/// Fills the per-dimension sparsity information for all tensors.
static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  unsigned numTensors = op.getNumShapedOperands();
  ArrayAttr sparseAttr = op.sparseAttr();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    auto dimAttr = sparseAttr[t].cast<ArrayAttr>();
    // For each tensor, we accept a per-dimension Sparse or Dense annotation.
    // This is translated to the loop index that indexes that dimension.
    unsigned rank = op.getShapedType(t).getRank();
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      if (isSparseDim(dimAttr[d])) {
        merger.setDim(t, idx, Dim::kSparse);
      } else {
        assert(isDenseDim(dimAttr[d]));
        merger.setDim(t, idx, Dim::kDense);
      }
    }
  }
}

/// Returns true if tensor was set up with sparse storage scheme.
static bool linkedSparse(linalg::GenericOp op, unsigned tensor) {
  if (tensor < op.getNumInputs())
    return isa_and_nonnull<sparse_tensor::FromPointerOp>(
        op.getInput(tensor).getDefiningOp());
  return false;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse only is requested.
    if (sparseOnly && !merger.isSparseTensor(t) && !linkedSparse(op, t))
      continue;
    // At the moment, we take the index variables in the tensor access
    // expression in the order in which they appear (conceptually a
    // "row-major" layout of every tensor). So, a tensor access A_ijk
    // forces the ordering i < j < k on the loop indices.
    // TODO: support affine map to define alternative dimension orders.
    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
      unsigned f = map.getDimPosition(d - 1);
      unsigned t = map.getDimPosition(d);
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    if (arg.getOwner()->getParentOp() == op) {
      // Any parameter of the generic op is considered a tensor,
      // indexed by the implicit loop bounds.
      auto map = op.getIndexingMap(argN);
      if (map.isProjectedPermutation())
        return merger.addExp(Kind::kTensor, argN);
      // Cannot handle (yet).
      return None;
    }
    // Any parameter of a higher op is invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = merger.addSet();
    unsigned t =
        kind == Kind::kTensor ? merger.exp(exp).e0 : op.getNumShapedOperands();
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, SparseIntType tp) {
  switch (tp) {
  case SparseIntType::kNative:
    return rewriter.getIndexType();
  case SparseIntType::kI64:
    return rewriter.getIntegerType(64);
  case SparseIntType::kI32:
    return rewriter.getIntegerType(32);
  case SparseIntType::kI16:
    return rewriter.getIntegerType(16);
  case SparseIntType::kI8:
    return rewriter.getIntegerType(8);
  }
  llvm_unreachable("unexpected SparseIntType");
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutput(0);
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor. Currently this results in functional,
  // but slightly imprecise IR, so it is put under an experimental option.
  if (codegen.options.fastOutput)
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proper bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  unsigned numTensors = op.getNumShapedOperands();
  unsigned numInputs = op.getNumInputs();
  assert(numTensors == numInputs + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (unsigned t = 0; t < numTensors; t++) {
    Value tensor = t < numInputs ? op.getInput(t) : op.getOutput(0);
    auto tensorType = op.getShapedType(t);
    auto shape = tensorType.getShape();
    auto map = op.getIndexingMap(t);
    // Scan all dimensions of current tensor.
    bool dense = !linkedSparse(op, t);
    args.clear();
    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
      unsigned i = map.getDimPosition(d);
      // Handle sparse storage schemes.
      if (merger.isDim(t, i, Dim::kSparse)) {
        dense = false;
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, codegen.options.ptrType));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, codegen.options.indType));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[t][i] = rewriter.create<sparse_tensor::ToPointersOp>(
            loc, ptrTp, tensor, dim);
        codegen.indices[t][i] = rewriter.create<sparse_tensor::ToIndicesOp>(
            loc, indTp, tensor, dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<memref::DimOp>(loc, tensor, d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[i] = codegen.highs[t][i] = up;
    }
    // Perform the required bufferization. All dense inputs materialize
    // from the input tensor. The dense output tensor needs special
    // handling. Sparse inputs use a sparse primitive to obtain the values.
    if (dense) {
      auto denseTp = MemRefType::get(shape, tensorType.getElementType());
      if (t < numInputs)
        codegen.buffers[t] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
      else
        codegen.buffers[t] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, tensorType.getElementType());
      codegen.buffers[t] =
          rewriter.create<sparse_tensor::ToValuesOp>(loc, sparseTp, tensor);
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
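  // Conceptually, the generated mask enables lanes [0, min(hi - iv, vl)) for
  // vector length vl, so the final partial iteration stays within bounds
  // (illustrative description of the code below).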
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  unsigned tensor = merger.exp(exp).e0;
  auto map = op.getIndexingMap(tensor);
  bool sparse = linkedSparse(op, tensor);
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
    if (sparse || merger.isDim(tensor, idx, Dim::kSparse)) {
      sparse = true;
      args.clear();
      args.push_back(codegen.pidxs[tensor][idx]); // position index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned tensor, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (lhs == tensor && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  auto map = op.getIndexingMap(tensor);
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (in
    // the future, we could introduce a flag that states the negative space
    // of 32-bit indices is unused). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
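/// This computes the linearized position of a dense dimension (illustrative
/// restatement): p is the compound position of the enclosing dimensions,
/// sz the size of the current dimension, and i the current loop index.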
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  unsigned lhs = op.getNumShapedOperands() - 1;
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    unsigned tensor = merger.exp(exp).e0;
    auto map = op.getIndexingMap(tensor);
    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
      unsigned idx = map.getDimPosition(i);
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    unsigned lhs = op.getNumShapedOperands() - 1;
    if (lhs == tensor) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).e0;
    unsigned e1 = merger.exp(exp).e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided loads/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  unsigned numTensors = op.getNumShapedOperands();
  for (unsigned t = 0; t < numTensors; t++) {
    if (!merger.isSparseTensor(t) && !linkedSparse(op, t)) {
      auto map = op.getIndexingMap(t);
      unsigned r = map.getNumResults();
      for (unsigned i = 0; i < r; i++) {
        if (map.getDimPosition(i) == idx && i != r - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
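/// A single sparse or dense index is handled with an scf.for; co-iteration
/// over several sparse indices requires an scf.while that conditionally
/// advances each position (summary of the dispatch below).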
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
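/// Each sparse position advances by one only when its index matched the
/// current loop index; the universal dense index, when maintained, always
/// advances by one (summary of the operands yielded below).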
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction,
                              ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned lhs = op.getNumShapedOperands() - 1;
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    if (!op.hasSparseSemantics())
      return failure();
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumShapedOperands();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    findSparseAnnotations(merger, op);

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    Value result = rewriter.create<memref::TensorLoadOp>(
        op.getLoc(), codegen.buffers.back());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
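
// Illustrative usage sketch: a conversion pass would typically populate and
// apply these patterns with the standard greedy driver, roughly as follows
// (assuming access to a FuncOp `func` and its MLIRContext `ctx`):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, SparsificationOptions());
//   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));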