//===- Sparsification.cpp - Implementation of sparsification -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering sparse tensor types to actual sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation), where the sparsity of tensors is indicated with annotations that
// give a per-dimension specification of sparse/dense storage together with a
// specification of the order on the dimensions. Subsequently, a topologically
// sorted iteration graph, reflecting the required order on indices with respect
// to the dimensions of each tensor, is constructed to ensure that all tensors
// are visited in natural index order. Next, iteration lattices are constructed
// for the tensor expression for every index in topological order. Each
// iteration lattice point consists of a conjunction of tensor indices together
// with a tensor (sub)expression that needs to be evaluated for that
// conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such these iteration lattices drive actual
// sparse code generation, which consists of a tedious but relatively
// straightforward one-to-one mapping from iteration lattices to combinations
// of for-loops, while-loops, and if-statements.
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February, 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

// Code generation.
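// As an illustration of what the rules below generate (a hedged sketch only;
// the exact attribute and op syntax may differ between MLIR versions, and
// #matvec_trait stands in for the usual indexing maps and iterator types),
// a matrix-vector product with a CSR-like annotated matrix
//
//   #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
//
//   %0 = linalg.generic #matvec_trait
//          ins(%A, %b : tensor<?x?xf64, #CSR>, tensor<?xf64>)
//          outs(%x : tensor<?xf64>) { ... } -> tensor<?xf64>
//
// is lowered into a dense outer scf.for over the rows wrapping a sparse inner
// scf.for over the compressed positions of each row, reading the pointers,
// indices, and values arrays obtained from the sparse storage scheme of %A.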
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
        curVecLength(1), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can "scalarize" the reduction.
  // TODO: currently only done for (a chain of) innermost for-loops, where it
  // is most effective; we could generalize to more outer and while-loops.
  unsigned redExp;
  Value redVal;
  // Current vector length and mask.
  unsigned curVecLength;
  Value curVecMask;
};

} // namespace

/// Helper method to apply dimension ordering permutation.
static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
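/// For example (an illustrative sketch, not tied to a particular test case):
/// a matrix operand annotated with dimLevelType = [ "dense", "compressed" ]
/// and an identity dimension ordering is registered with the merger as
/// Dim::kDense for the loop index of its first dimension and Dim::kSparse for
/// the loop index of its second dimension, while non-annotated operands
/// contribute Dim::kDense for all of their dimensions.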
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    if (!map.isProjectedPermutation())
      return false;
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  bool sparseOnly) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when sparse-only is requested.
    if (sparseOnly && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      unsigned f = map.getDimPosition(perm(enc, d - 1));
      unsigned t = map.getDimPosition(perm(enc, d));
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if the given tensor co-iterates with conjunction only.
/// For the output tensor, this defines a "simply dynamic" operation.
/// For instance: A(I) = A(I) * B(I) * C(I)
static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
    return merger.exp(exp).tensor == tensor;
  case Kind::kMulF:
  case Kind::kMulI:
    return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
           isConjunction(merger, tensor, merger.exp(exp).children.e1);
  default:
    return false;
  }
}

/// Returns true when the tensor expression is admissable for codegen.
/// Since all sparse input tensors are admissable, we just need to check
/// whether the output tensor in the tensor expression is admissable for
/// codegen.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissable since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissable since insertions cannot occur.
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissable without special codegen.
  if (isConjunction(merger, tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
  Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  rewriter.create<linalg::CopyOp>(loc, init, alloc);
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static bool genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
        Value dim = rewriter.create<ConstantIndexOp>(loc, d);
        // Generate sparse primitives to obtain pointers and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == MemRefType::kDynamicSize) {
        up = createOrFoldDimOp(rewriter, loc, t->get(), d);
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else {
      // Annotated sparse tensors.
      if (tensor == op.getNumInputs() && !getInPlace(t->get()))
        return false; // reject output if not in-place
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
  return true;
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1));
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  Value end = rewriter.create<SubIOp>(loc, hi, iv);
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a load on a dense or sparse tensor.
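/// A sparse operand is addressed through its position index into the linear
/// values array (buffer[pidx]), whereas a dense operand is addressed with the
/// universal dense indices of all its dimensions (buffer[i0, .., iR-1]); the
/// same addressing scheme is mirrored by genTensorStore below.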
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  OpOperand *lhs = op.getOutputOperand(0);
  if (lhs == t && codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs,
                                      codegen.redVal);
    codegen.redVal = rhs;
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    unsigned idx = map.getDimPosition(perm(enc, rank - 1));
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      unsigned idx = map.getDimPosition(d);
      args.push_back(codegen.loops[idx]); // universal dense index
    }
  }
  Value ptr = codegen.buffers[tensor];
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively define an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices.
    // For 32-bit values, zero extending the elements into 64-bit loses some
    // performance since the 32-bit indexed gather/scatter is more efficient
    // than the 64-bit index variant (if the negative 32-bit index space is
    // unused, the enableSIMDIndex32 flag can preserve this performance).
    // For 64-bit values, there is no good way to state that the indices are
    // unsigned, which creates the potential for incorrect address calculations
    // in the unlikely case we need such extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(32)));
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<ZeroExtendIOp>(
            loc, vload, vectorType(codegen, rewriter.getIntegerType(64)));
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<ZeroExtendIOp>(loc, load,
                                            rewriter.getIntegerType(64));
    load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType());
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<AddIOp>(loc, mul, i);
}

/// Generates start of a reduction.
static Value genReductionStart(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter,
                               linalg::GenericOp op) {
  if (codegen.redVal)
    return codegen.redVal; // chained with previous for-loop
  if (codegen.curVecLength > 1) {
    // TODO: assumes + reductions for now
    VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
    return rewriter.create<ConstantOp>(op.getLoc(), vtp,
                                       rewriter.getZeroAttr(vtp));
  }
  return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
}

/// Generates end of a reduction.
static void genReductionEnd(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op) {
  Value red = codegen.redVal;
  if (!red)
    return;
  assert(codegen.curVecLength == 1);
  codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
  OpOperand *lhs = op.getOutputOperand(0);
  if (auto vtp = red.getType().dyn_cast<VectorType>()) {
    // TODO: assumes + reductions for now
    StringAttr kind = rewriter.getStringAttr("add");
    Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
    // Integer reductions don't accept an accumulator.
    if (vtp.getElementType().isa<IntegerType>()) {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ValueRange{});
      red = rewriter.create<AddIOp>(op.getLoc(), red, ld);
    } else {
      red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(),
                                                 kind, red, ld);
    }
  }
  genTensorStore(merger, codegen, rewriter, op, lhs, red);
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool hoist) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (!codegen.loops[idx])
        return; // still in play
      else if (idx == ldx)
        atLevel = true;
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      codegen.redExp = hoist ? exp : -1u;
    } else if (atLevel) {
      merger.exp(exp).val =
          hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit strides for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// For now, we reject vectorization of such cases.
/// TODO: implement strided load/stores on dense arrays
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        if (map.getDimPosition(d) == idx && d != rank - 1)
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop, potentially with a scalarized reduction.
  bool scalarRed = isInner && codegen.redExp != -1u;
  SmallVector<Value, 4> operands;
  if (scalarRed) {
    Value load = genReductionStart(merger, codegen, rewriter, op);
    operands.push_back(load);
  }
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (scalarRed) {
    codegen.redVal = merger.exp(codegen.redExp).val =
        forOp.getRegionIterArgs().front();
  }
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emits a while-loop for co-iteration over multiple indices.
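/// The generated loop has roughly the following shape (an illustrative sketch
/// only; the actual IR carries the full types and operand lists):
///
///   scf.while (%p_0 = ..., %p_1 = ..., %i = ...) {
///     %c = and(cmpi ult %p_0, %hi_0, cmpi ult %p_1, %hi_1)
///     scf.condition(%c) ...
///   } do {
///     <loop body, with if-statements for the lattice points>
///     scf.yield <updated positions and universal index>
///   }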
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
             "type mismatch for sparse index");
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    assert(codegen.loops[idx].getType().isa<IndexType>() &&
           "type mismatch for universal index");
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists of a
  // conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  genReductionEnd(merger, codegen, rewriter, op); // cannot chain
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> exp = merger.buildTensorExpFromLinalg(op);
    if (!exp.hasValue())
      return failure();

    // Rejects an inadmissable tensor expression.
    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}