//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering sparse tensor types to actual sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation), where the sparsity of tensors is indicated with an annotation
// giving a per-dimension specification of sparse/dense storage together with
// a specification of the order on the dimensions. Subsequently, a
// topologically sorted iteration graph, reflecting the required order on
// indices with respect to the dimensions of each tensor, is constructed to
// ensure that all tensors are visited in natural index order. Next, iteration
// lattices are constructed for the tensor expression for every index in
// topological order. Each iteration lattice point consists of a conjunction
// of tensor indices together with a tensor (sub)expression that needs to be
// evaluated for that conjunction. Within the lattice, iteration points are
// ordered according to the way indices are exhausted. As such these iteration
// lattices drive actual sparse code generation, which consists of a tedious
// but relatively straightforward one-to-one mapping from iteration lattices
// to combinations of for-loops, while-loops, and if-statements.
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
//         PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
//         David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
//         Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
//         PhD thesis, MIT, February 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {
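// Informal example of the kind of code that is generated (for exposition
// only; the exact IR produced below differs): for a kernel such as
// x(i) = a(i) * b(i), with both a and b annotated as compressed sparse
// vectors, sparsification co-iterates over the intersection of their
// nonzero index sets, roughly
//
//   while (pa < pa_end && pb < pb_end) {
//     ia = a.indices[pa]; ib = b.indices[pb];
//     if (ia == ib) x[ia] = a.values[pa] * b.values[pb];
//     pa += (ia <= ib); pb += (ib <= ia);
//   }
//
// whereas a disjunction (addition) would iterate over the union of the
// index sets. The CodeGen struct below carries the SSA values needed while
// emitting such loops.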
// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// These arrays are set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}
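// Informal example (for exposition only): a matrix annotated with dimension
// level types (dense, compressed), i.e. CSR-like storage, maps its first
// dimension to Dim::kDense and its second to Dim::kSparse, while an optional
// dimension ordering permutation (see perm above) decides which loop index
// is associated with each stored dimension.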
/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse-only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}
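// Informal example (for exposition only): for a matrix-vector product
// x(i) = SUM_j A(i,j) * b(j) with A stored row-wise, the access A(i,j)
// contributes the edge i -> j to the adjacency matrix, so the topological
// sort yields the loop order (i, j); a column-wise dimension ordering on A
// would instead force the order (j, i).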
/// Returns true if the given tensor co-iterates with conjunction only.
/// For the output tensor, this defines a "simply dynamic" operation.
/// For instance: A(I) = A(I) * B(I) * C(I)
static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
    return merger.exp(exp).tensor == tensor;
  case Kind::kMulF:
  case Kind::kMulI:
    return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
           isConjunction(merger, tensor, merger.exp(exp).children.e1);
  default:
    return false;
  }
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression codegen is admissable.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen.
  if (isConjunction(merger, tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that is
  // generated for the tensor present in the outs() clause. This has the
  // major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = rewriter.create<tensor::DimOp>(loc, t->get(), d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}
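// Informal example (for exposition only): for a 2-d tensor with dimension
// level types (dense, compressed), i.e. CSR-like storage, genBuffers yields
// one pointers array and one indices array for the compressed dimension
// (through the ToPointersOp/ToIndicesOp primitives used above) plus a 1-d
// values buffer (through ToValuesOp), while a non-annotated operand is
// simply bufferized into a dense memref of its full shape.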
/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices.
    // For 32-bit values, zero extending the elements into 64-bit loses some
    // performance since the 32-bit indexed gather/scatter is more efficient
    // than the 64-bit index variant (if the negative 32-bit index space is
    // unused, the enableSIMDIndex32 flag can preserve this performance).
    // For 64-bit values, there is no good way to state that the indices are
    // unsigned, which creates the potential of incorrect address calculations
    // in the unlikely case we need such extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}
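// Informal sketch (for exposition only) of the scalarized reduction chain
// that genReductionStart/genReductionEnd set up around an innermost for-loop
// computing, say, x() += a(i) * b(i):
//
//   %init = <load of x>                        (genReductionStart)
//   %red = scf.for ... iter_args(%r = %init) {
//     %v = <%r + a(i) * b(i)>
//     scf.yield %v
//   }
//   <store %red into x>                        (genReductionEnd)
//
// In the vectorized case the accumulator is a vector that is combined with
// a final vector.reduction before the store.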
/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  Location loc = op.getLoc();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  if (merger.exp(exp).kind == Kind::kZero) {
    Type tp = op.getOutputTensorTypes()[0].getElementType();
    merger.exp(exp).val =
        rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
    return genInvariantValue(merger, codegen, rewriter, exp);
  }
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case Kind::kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case Kind::kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant &&
             merger.exp(exp).kind != Kind::kZero) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}
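// Informal sketch (for exposition only) of the co-iteration structure that
// genWhile, genLocals, genIf, and genWhileInduction emit together for two
// sparse operands a and b at loop index i:
//
//   scf.while (%pa, %pb):
//     before: condition %pa < pa_end && %pb < pb_end
//     after:
//       ia = a.indices[%pa]; ib = b.indices[%pb]; i = min(ia, ib)
//       if (ia == i && ib == i) { ... a(i) op b(i) ... }
//       else if (ia == i)       { ... a(i) ... }
//       else if (ib == i)       { ... b(i) ... }
//       yield (%pa + (ia == i), %pb + (ib == i))
//
// A universal dense index, when needed, is carried as one extra while-loop
// argument and advanced by one each iteration.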
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
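// Example usage (informal sketch; the surrounding pass boilerplate is
// assumed, not defined in this file): a pass would typically wire these
// patterns into the greedy rewrite driver along the lines of
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));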