1 //===- Sparsification.cpp - Implementation of sparsification --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering sparse tensor types to actual sparse code. 10 // 11 // The concept of letting a compiler generate sparse code automatically was 12 // pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and 13 // formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor 14 // Algebra Compiler (TACO). The implementation in this file closely follows 15 // the "sparse iteration theory" that forms the foundation of TACO. A rewriting 16 // rule is applied to each tensor expression in linalg (MLIR's tensor index 17 // notation) where the sparsity of tensors is indicated with annotation using 18 // a per-dimension specification of sparse/dense storage together with a 19 // specification of the order on the dimensions. Subsequently, a topologically 20 // sorted iteration graph, reflecting the required order on indices with respect 21 // to the dimensions of each tensor, is constructed to ensure that all tensors 22 // are visited in natural index order. Next, iteration lattices are constructed 23 // for the tensor expression for every index in topological order. Each 24 // iteration lattice point consists of a conjunction of tensor indices together 25 // with a tensor (sub)expression that needs to be evaluated for that 26 // conjunction. Within the lattice, iteration points are ordered according to 27 // the way indices are exhausted. As such these iteration lattices drive actual 28 // sparse code generation, which consists of a tedious but relatively 29 // straightforward one-to-one mapping from iteration lattices to combinations 30 // of for-loops, while-loops, and if-statements. 31 // 32 // [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations. 33 // PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php). 34 // [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou, 35 // David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler. 36 // Proceedings of the ACM on Programming Languages, October 2017. 37 // [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation. 38 // PhD thesis, MIT, February, 2020 (tensor-compiler.org). 39 // 40 // Implementation detail: We use llvm::SmallVector for vectors with 41 // variable lengths and std::vector for vectors with fixed lengths. 42 //===----------------------------------------------------------------------===// 43 44 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 45 #include "mlir/Dialect/Linalg/Utils/Utils.h" 46 #include "mlir/Dialect/MemRef/IR/MemRef.h" 47 #include "mlir/Dialect/SCF/SCF.h" 48 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 49 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" 50 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 51 #include "mlir/Dialect/StandardOps/IR/Ops.h" 52 #include "mlir/Dialect/Vector/VectorOps.h" 53 #include "mlir/IR/Matchers.h" 54 #include "mlir/IR/TensorEncoding.h" 55 #include "llvm/ADT/SmallBitVector.h" 56 57 using namespace mlir; 58 using namespace mlir::sparse_tensor; 59 60 namespace { 61 62 // Code generation. 
63 struct CodeGen { 64 CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops) 65 : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), 66 pointers(numTensors, std::vector<Value>(numLoops)), 67 indices(numTensors, std::vector<Value>(numLoops)), 68 highs(numTensors, std::vector<Value>(numLoops)), 69 pidxs(numTensors, std::vector<Value>(numLoops)), 70 idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(), 71 curVecLength(1), curVecMask() {} 72 /// Sparsification options. 73 SparsificationOptions options; 74 /// Universal dense indices and upper bounds (by index). The loops array 75 /// is updated with the value of the universal dense index in the current 76 /// loop. The sizes array is set once with the inferred dimension sizes. 77 std::vector<Value> loops; 78 std::vector<Value> sizes; 79 /// Buffers for storing dense and sparse numerical values (by tensor). 80 /// This array is set once during bufferization of all tensors. 81 std::vector<Value> buffers; 82 /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). 83 /// This array is set once during bufferization of all sparse tensors. 84 std::vector<std::vector<Value>> pointers; 85 std::vector<std::vector<Value>> indices; 86 /// Sparse iteration information (by tensor and index). These arrays 87 /// are updated to remain current within the current loop. 88 std::vector<std::vector<Value>> highs; 89 std::vector<std::vector<Value>> pidxs; 90 std::vector<std::vector<Value>> idxs; 91 /// Current reduction, updated during code generation. When indices of a 92 /// reduction are exhausted, all inner loops can "scalarize" the reduction. 93 // TODO: currently only done for (a chain of) innermost for-loops, where it 94 // is most effective; we could generalize to more outer and while-loops. 95 unsigned redExp; 96 Value redVal; 97 // Current vector length and mask. 98 unsigned curVecLength; 99 Value curVecMask; 100 }; 101 102 } // namespace 103 104 // Helper method to apply dimension ordering permutation. 105 static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) { 106 if (enc) { 107 auto order = enc.getDimOrdering(); 108 if (order) { 109 assert(order.isPermutation()); 110 return order.getDimPosition(d); 111 } 112 } 113 return d; 114 } 115 116 // Helper method to translate dim level type to internal representation. 117 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) { 118 if (enc) { 119 SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d]; 120 if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed) 121 return Dim::kSparse; 122 if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton) 123 return Dim::kSingle; 124 } 125 return Dim::kDense; 126 } 127 128 /// Helper method to inspect sparse encodings in the tensor types. 129 /// Fills the per-dimension sparsity information for all tensors. 
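///
/// For example, a CSR-like matrix annotation assigns "dense" storage to the
/// row dimension and "compressed" storage to the column dimension, roughly
/// (schematic form of the encoding attribute; exact syntax may differ):
///
///   #CSR = #sparse_tensor.encoding<{
///     dimLevelType = [ "dense", "compressed" ]
///   }>
///
/// which this helper records as Dim::kDense for d=0 and Dim::kSparse for d=1.
/// The compressed dimension is later accessed through the usual pointer and
/// index arrays of compressed storage, e.g. for the matrix
///
///   [ 1 0 2 ]
///   [ 0 0 0 ]
///   [ 0 3 4 ]
///
/// pointers = [0, 2, 2, 4], indices = [0, 2, 1, 2], values = [1, 2, 3, 4].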
130 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { 131 bool annotated = false; 132 for (OpOperand *t : op.getInputAndOutputOperands()) { 133 auto map = op.getTiedIndexingMap(t); 134 if (!map.isProjectedPermutation()) 135 return false; 136 auto enc = getSparseTensorEncoding(t->get().getType()); 137 if (enc) 138 annotated = true; 139 assert(map.getNumResults() == op.getRank(t)); 140 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 141 unsigned idx = map.getDimPosition(perm(enc, d)); 142 merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); 143 } 144 } 145 return annotated; 146 } 147 148 /// A DFS helper to compute a topological sort. Note that recursion is 149 /// bounded by the number of implicit loops, which is always small. 150 /// Returns false when a cycle is detected. 151 static bool topSortDFS(unsigned i, std::vector<unsigned> &visit, 152 std::vector<unsigned> &topSort, 153 std::vector<std::vector<bool>> &adjM) { 154 if (visit[i] != 0) 155 return visit[i] != 1; // 1 denotes cycle! 156 visit[i] = 1; 157 for (unsigned j = 0, e = visit.size(); j < e; j++) 158 if (adjM[i][j]) 159 if (!topSortDFS(j, visit, topSort, adjM)) 160 return false; 161 visit[i] = 2; 162 topSort.push_back(i); 163 return true; 164 } 165 166 /// Computes a topologically sorted iteration graph for the linalg operation. 167 /// Ensures all tensors are visited in natural index order. This is essential 168 /// for sparse storage formats since these only support access along fixed 169 /// dimensions. Even for dense storage formats, however, the natural index 170 /// order yields innermost unit-stride access with better spatial locality. 171 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, 172 std::vector<unsigned> &topSort, 173 bool sparseOnly) { 174 // Set up an n x n from/to adjacency matrix of the iteration graph 175 // for the implicit loop indices i_0 .. i_n-1. 176 unsigned n = op.getNumLoops(); 177 std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false)); 178 179 // Iterate over the indexing maps of every tensor in the tensor expression. 180 for (OpOperand *t : op.getInputAndOutputOperands()) { 181 auto map = op.getTiedIndexingMap(t); 182 auto enc = getSparseTensorEncoding(t->get().getType()); 183 assert(map.getNumDims() == n); 184 // Skip dense tensor constraints when sparse only is requested. 185 if (sparseOnly && !enc) 186 continue; 187 // Each tensor expression and optional dimension ordering (row-major 188 // by default) puts an ordering constraint on the loop indices. For 189 // example, the tensor expresion A_ijk forces the ordering i < j < k 190 // on the loop indices if no explicit dimension ordering is given. 191 for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { 192 unsigned f = map.getDimPosition(perm(enc, d - 1)); 193 unsigned t = map.getDimPosition(perm(enc, d)); 194 adjM[f][t] = true; 195 } 196 } 197 198 // Topologically sort the iteration graph to determine loop order. 199 // Report failure for a cyclic iteration graph. 200 topSort.clear(); 201 topSort.reserve(n); 202 std::vector<unsigned> visit(n, 0); 203 for (unsigned i = 0; i < n; i++) 204 if (visit[i] == 0) 205 if (!topSortDFS(i, visit, topSort, adjM)) 206 return false; // cycle! 207 std::reverse(std::begin(topSort), std::end(topSort)); 208 return true; 209 } 210 211 /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. 
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return merger.addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if both subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Returns true if the given tensor co-iterates with conjunction only.
/// For the output tensor, this defines a "simply dynamic" operation.
/// For instance: A(I) = A(I) * B(I) * C(I)
static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
    return merger.exp(exp).e0 == tensor;
  case Kind::kMulF:
  case Kind::kMulI:
    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
           isConjunction(merger, tensor, merger.exp(exp).e1);
  default:
    return false;
  }
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the output tensor in the tensor expression codegen is admissible.
static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
                                  unsigned exp) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
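  // To illustrate the cases handled below (editorial examples, in the tensor
  // index notation used elsewhere in this file): with a truly sparse output
  // x, a kernel such as x(i) = a(i) * x(i) only updates values at already
  // existing nonzero positions and is accepted as "simply dynamic", whereas
  // x(i) = a(i) + b(i) may introduce new nonzero positions in x and is
  // rejected until insertion support (workspaces) is available.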
  bool allDense = true;
  unsigned numLoops = op.iterator_types().getValue().size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen.
  if (isConjunction(merger, tensor, exp))
    return true;
  // Reject for now since this requires changes to the nonzero structure.
  // TODO: implement "workspaces" [Kjolstad2019]
  return false;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
/// For example, for the expression a(i)+b(i) the lattice for index i contains
/// the points {a,b}, {a}, and {b}, whereas for a(i)*b(i) only the
/// conjunction {a,b} remains.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = merger.addSet();
    unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
                                       : op.getNumInputsAndOutputs();
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, unsigned width) {
  if (width == 0)
    return rewriter.getIndexType();
  return rewriter.getIntegerType(width);
}

/// Detects in-place annotation on tensor argument.
static bool getInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Generates buffer for the output tensor.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (getInPlace(tensor))
    return rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel.
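  // For example, for a vector kernel with nnz nonzeros and dense length n,
  // the alloc-and-copy below adds O(n) work even when the sparse update
  // itself is only O(nnz), which is why the in-place path above is preferred
  // whenever the annotation allows it.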
369 Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor); 370 Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args); 371 rewriter.create<linalg::CopyOp>(loc, init, alloc); 372 return alloc; 373 } 374 375 /// Local bufferization of all dense and sparse data structures. 376 /// This code enables testing the first prototype sparse compiler. 377 // TODO: replace this with a proliferated bufferization strategy 378 static bool genBuffers(Merger &merger, CodeGen &codegen, 379 PatternRewriter &rewriter, linalg::GenericOp op) { 380 Location loc = op.getLoc(); 381 assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); 382 // For every tensor, find lower and upper bound on dimensions, set the 383 // same bounds on loop indices, and obtain dense or sparse buffer(s). 384 SmallVector<Value, 4> args; 385 for (OpOperand *t : op.getInputAndOutputOperands()) { 386 unsigned tensor = t->getOperandNumber(); 387 auto shape = op.getShape(t); 388 auto map = op.getTiedIndexingMap(t); 389 auto enc = getSparseTensorEncoding(t->get().getType()); 390 // Scan all dimensions of current tensor. 391 args.clear(); 392 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 393 unsigned idx = map.getDimPosition(perm(enc, d)); 394 // Handle sparse storage schemes. 395 if (merger.isDim(tensor, idx, Dim::kSparse)) { 396 auto dynShape = {ShapedType::kDynamicSize}; 397 auto ptrTp = MemRefType::get( 398 dynShape, genIntType(rewriter, enc.getPointerBitWidth())); 399 auto indTp = MemRefType::get( 400 dynShape, genIntType(rewriter, enc.getIndexBitWidth())); 401 Value dim = rewriter.create<ConstantIndexOp>(loc, d); 402 // Generate sparse primitives to obtains pointer and indices. 403 codegen.pointers[tensor][idx] = 404 rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim); 405 codegen.indices[tensor][idx] = 406 rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim); 407 } 408 // Find lower and upper bound in current dimension. 409 Value up; 410 if (shape[d] == MemRefType::kDynamicSize) { 411 up = rewriter.create<memref::DimOp>(loc, t->get(), d); 412 args.push_back(up); 413 } else { 414 up = rewriter.create<ConstantIndexOp>(loc, shape[d]); 415 } 416 codegen.sizes[idx] = codegen.highs[tensor][idx] = up; 417 } 418 // Perform the required bufferization. Dense inputs materialize 419 // from the input tensors. Dense outputs need special handling. 420 // Sparse inputs use sparse primitives to obtain the values. 421 // We also accept in-place all-dense annotated "sparse" outputs. 422 Type elementType = getElementTypeOrSelf(t->get().getType()); 423 if (!enc) { 424 // Non-annotated dense tensors. 425 auto denseTp = MemRefType::get(shape, elementType); 426 if (tensor < op.getNumInputs()) 427 codegen.buffers[tensor] = 428 rewriter.create<memref::BufferCastOp>(loc, denseTp, t->get()); 429 else 430 codegen.buffers[tensor] = 431 genOutputBuffer(codegen, rewriter, op, denseTp, args); 432 } else { 433 // Annotated sparse tensors. 434 if (tensor == op.getNumInputs() && !getInPlace(t->get())) 435 return false; // reject output if not in-place 436 auto dynShape = {ShapedType::kDynamicSize}; 437 auto sparseTp = MemRefType::get(dynShape, elementType); 438 codegen.buffers[tensor] = 439 rewriter.create<ToValuesOp>(loc, sparseTp, t->get()); 440 } 441 } 442 return true; 443 } 444 445 /// Constructs vector type. 446 static VectorType vectorType(CodeGen &codegen, Type etp) { 447 return VectorType::get(codegen.curVecLength, etp); 448 } 449 450 /// Constructs vector type from pointer. 
451 static VectorType vectorType(CodeGen &codegen, Value ptr) { 452 return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType()); 453 } 454 455 /// Constructs vector iteration mask. 456 static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter, 457 Value iv, Value lo, Value hi, Value step) { 458 Location loc = iv.getLoc(); 459 VectorType mtp = vectorType(codegen, rewriter.getIntegerType(1)); 460 // Special case if the vector length evenly divides the trip count (for 461 // example, "for i = 0, 128, 16"). A constant all-true mask is generated 462 // so that all subsequent masked memory operations are immediately folded 463 // into unconditional memory operations. 464 IntegerAttr loInt, hiInt, stepInt; 465 if (matchPattern(lo, m_Constant(&loInt)) && 466 matchPattern(hi, m_Constant(&hiInt)) && 467 matchPattern(step, m_Constant(&stepInt))) { 468 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) 469 return rewriter.create<vector::BroadcastOp>( 470 loc, mtp, rewriter.create<ConstantIntOp>(loc, 1, 1)); 471 } 472 // Otherwise, generate a vector mask that avoids overrunning the upperbound 473 // during vector execution. Here we rely on subsequent loop optimizations to 474 // avoid executing the mask in all iterations, for example, by splitting the 475 // loop into an unconditional vector loop and a scalar cleanup loop. 476 Value end = rewriter.create<SubIOp>(loc, hi, iv); 477 return rewriter.create<vector::CreateMaskOp>(loc, mtp, end); 478 } 479 480 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi]. 481 static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter, 482 Value ptr, ArrayRef<Value> args) { 483 Location loc = ptr.getLoc(); 484 VectorType vtp = vectorType(codegen, ptr); 485 Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp)); 486 if (args.back().getType().isa<VectorType>()) { 487 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 488 Value indexVec = args.back(); 489 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 490 return rewriter.create<vector::GatherOp>( 491 loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass); 492 } 493 return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args, 494 codegen.curVecMask, pass); 495 } 496 497 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs. 498 static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter, 499 Value rhs, Value ptr, ArrayRef<Value> args) { 500 Location loc = ptr.getLoc(); 501 if (args.back().getType().isa<VectorType>()) { 502 SmallVector<Value, 4> scalarArgs(args.begin(), args.end()); 503 Value indexVec = args.back(); 504 scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0); 505 rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, 506 codegen.curVecMask, rhs); 507 return; 508 } 509 rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask, 510 rhs); 511 } 512 513 /// Generates a vectorized invariant. Here we rely on subsequent loop 514 /// optimizations to hoist the invariant broadcast out of the vector loop. 515 static Value genVectorInvariantValue(CodeGen &codegen, 516 PatternRewriter &rewriter, Value val) { 517 VectorType vtp = vectorType(codegen, val.getType()); 518 return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val); 519 } 520 521 /// Generates a load on a dense or sparse tensor. 
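/// For a dense operand the load is addressed by the universal loop indices
/// (e.g. b[i][j]), whereas for a sparse operand it is addressed by the
/// current position index into the values array (e.g. a_values[pidx]).
/// Under vectorization, the load becomes a masked vector load (unit-stride
/// access) or a gather (indirect access), sketched as:
///
///   vec = masked_load(values + pidx, mask, pass_through)
///   vec = gather(base, index_vector, mask, pass_through)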
522 static Value genTensorLoad(Merger &merger, CodeGen &codegen, 523 PatternRewriter &rewriter, linalg::GenericOp op, 524 unsigned exp) { 525 // Test if the load was hoisted to a higher loop nest. 526 Value val = merger.exp(exp).val; 527 if (val) { 528 if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>()) 529 return genVectorInvariantValue(codegen, rewriter, val); 530 return val; 531 } 532 // Actual load. 533 SmallVector<Value, 4> args; 534 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 535 unsigned tensor = t->getOperandNumber(); 536 auto map = op.getTiedIndexingMap(t); 537 auto enc = getSparseTensorEncoding(t->get().getType()); 538 unsigned rank = map.getNumResults(); 539 if (enc) { 540 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 541 assert(codegen.pidxs[tensor][idx] != nullptr); 542 args.push_back(codegen.pidxs[tensor][idx]); // position index 543 } else { 544 for (unsigned d = 0; d < rank; d++) { 545 unsigned idx = map.getDimPosition(d); 546 args.push_back(codegen.loops[idx]); // universal dense index 547 } 548 } 549 Location loc = op.getLoc(); 550 Value ptr = codegen.buffers[tensor]; 551 if (codegen.curVecLength > 1) 552 return genVectorLoad(codegen, rewriter, ptr, args); 553 return rewriter.create<memref::LoadOp>(loc, ptr, args); 554 } 555 556 /// Generates a store on a dense or sparse tensor. 557 static void genTensorStore(Merger &merger, CodeGen &codegen, 558 PatternRewriter &rewriter, linalg::GenericOp op, 559 OpOperand *t, Value rhs) { 560 Location loc = op.getLoc(); 561 // Test if this is a scalarized reduction. 562 OpOperand *lhs = op.getOutputOperand(0); 563 if (lhs == t && codegen.redVal) { 564 if (codegen.curVecLength > 1) 565 rhs = rewriter.create<SelectOp>(loc, codegen.curVecMask, rhs, 566 codegen.redVal); 567 codegen.redVal = rhs; 568 return; 569 } 570 // Actual store. 571 SmallVector<Value, 4> args; 572 unsigned tensor = t->getOperandNumber(); 573 auto map = op.getTiedIndexingMap(t); 574 auto enc = getSparseTensorEncoding(t->get().getType()); 575 unsigned rank = map.getNumResults(); 576 if (enc) { 577 unsigned idx = map.getDimPosition(perm(enc, rank - 1)); 578 assert(codegen.pidxs[tensor][idx] != nullptr); 579 args.push_back(codegen.pidxs[tensor][idx]); // position index 580 } else { 581 for (unsigned d = 0; d < rank; d++) { 582 unsigned idx = map.getDimPosition(d); 583 args.push_back(codegen.loops[idx]); // universal dense index 584 } 585 } 586 Value ptr = codegen.buffers[tensor]; 587 if (codegen.curVecLength > 1) 588 genVectorStore(codegen, rewriter, rhs, ptr, args); 589 else 590 rewriter.create<memref::StoreOp>(loc, rhs, ptr, args); 591 } 592 593 /// Generates a pointer/index load from the sparse storage scheme. Narrower 594 /// data types need to be zero extended before casting the value into the 595 /// index type used for looping and indexing. 596 static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc, 597 Value ptr, Value s) { 598 // See https://llvm.org/docs/GetElementPtr.html for some background on 599 // the complications described below. 600 if (codegen.curVecLength > 1) { 601 // Since the index vector is used in a subsequent gather/scatter operations, 602 // which effectively defines an unsigned pointer + signed index, we must 603 // zero extend the vector to an index width. For 8-bit and 16-bit values, 604 // an 32-bit index width suffices. 
For 32-bit values, zero extending the 605 // elements into 64-bit loses some performance since the 32-bit indexed 606 // gather/scatter is more efficient than the 64-bit index variant (if the 607 // negative 32-bit index space is unused, the enableSIMDIndex32 flag can 608 // preserve this performance). For 64-bit values, there is no good way 609 // to state that the indices are unsigned, with creates the potential of 610 // incorrect address calculations in the unlikely case we need such 611 // extremely large offsets. 612 Type etp = ptr.getType().cast<MemRefType>().getElementType(); 613 Value vload = genVectorLoad(codegen, rewriter, ptr, {s}); 614 if (!etp.isa<IndexType>()) { 615 if (etp.getIntOrFloatBitWidth() < 32) 616 vload = rewriter.create<ZeroExtendIOp>( 617 loc, vload, vectorType(codegen, rewriter.getIntegerType(32))); 618 else if (etp.getIntOrFloatBitWidth() < 64 && 619 !codegen.options.enableSIMDIndex32) 620 vload = rewriter.create<ZeroExtendIOp>( 621 loc, vload, vectorType(codegen, rewriter.getIntegerType(64))); 622 } 623 return vload; 624 } 625 // For the scalar case, we simply zero extend narrower indices into 64-bit 626 // values before casting to index without a performance penalty. Here too, 627 // however, indices that already are 64-bit, in theory, cannot express the 628 // full range as explained above. 629 Value load = rewriter.create<memref::LoadOp>(loc, ptr, s); 630 if (!load.getType().isa<IndexType>()) { 631 if (load.getType().getIntOrFloatBitWidth() < 64) 632 load = rewriter.create<ZeroExtendIOp>(loc, load, 633 rewriter.getIntegerType(64)); 634 load = rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType()); 635 } 636 return load; 637 } 638 639 /// Generates an invariant value. 640 static Value genInvariantValue(Merger &merger, CodeGen &codegen, 641 PatternRewriter &rewriter, unsigned exp) { 642 Value val = merger.exp(exp).val; 643 if (codegen.curVecLength > 1) 644 return genVectorInvariantValue(codegen, rewriter, val); 645 return val; 646 } 647 648 /// Generates an address computation "sz * p + i". 649 static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter, 650 Location loc, Value size, Value p, Value i) { 651 Value mul = rewriter.create<MulIOp>(loc, size, p); 652 if (auto vtp = i.getType().dyn_cast<VectorType>()) { 653 Value inv = rewriter.create<IndexCastOp>(loc, mul, vtp.getElementType()); 654 mul = genVectorInvariantValue(codegen, rewriter, inv); 655 } 656 return rewriter.create<AddIOp>(loc, mul, i); 657 } 658 659 /// Generates start of a reduction. 660 static Value genReductionStart(Merger &merger, CodeGen &codegen, 661 PatternRewriter &rewriter, 662 linalg::GenericOp op) { 663 if (codegen.redVal) 664 return codegen.redVal; // chained with previous for-loop 665 if (codegen.curVecLength > 1) { 666 // TODO: assumes + reductions for now 667 VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]); 668 return rewriter.create<ConstantOp>(op.getLoc(), vtp, 669 rewriter.getZeroAttr(vtp)); 670 } 671 return genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 672 } 673 674 /// Generates end of a reduction. 
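/// Together with genReductionStart, this scalarizes a reduction across a
/// chain of innermost for-loops. Schematically, x(i) = sum_j a(i,j) * b(j)
/// with a sparse in j becomes (scalar sketch):
///
///   red = x[i];                                    // genReductionStart
///   for (jj = lo; jj < hi; jj++)
///     red += a_values[jj] * b[a_indices[jj]];
///   x[i] = red;                                    // genReductionEnd
///
/// and, when vectorized, the accumulator is a vector that is horizontally
/// reduced (vector.reduction) into the scalar result at the end.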
675 static void genReductionEnd(Merger &merger, CodeGen &codegen, 676 PatternRewriter &rewriter, linalg::GenericOp op) { 677 Value red = codegen.redVal; 678 if (!red) 679 return; 680 assert(codegen.curVecLength == 1); 681 codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain 682 OpOperand *lhs = op.getOutputOperand(0); 683 if (auto vtp = red.getType().dyn_cast<VectorType>()) { 684 // TODO: assumes + reductions for now 685 StringAttr kind = rewriter.getStringAttr("add"); 686 Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); 687 // Integer reductions don't accept an accumulator. 688 if (vtp.getElementType().isa<IntegerType>()) { 689 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 690 kind, red, ValueRange{}); 691 red = rewriter.create<AddIOp>(op.getLoc(), red, ld); 692 } else { 693 red = rewriter.create<vector::ReductionOp>(op.getLoc(), ld.getType(), 694 kind, red, ld); 695 } 696 } 697 genTensorStore(merger, codegen, rewriter, op, lhs, red); 698 } 699 700 /// Recursively generates tensor expression. 701 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 702 linalg::GenericOp op, unsigned exp) { 703 if (merger.exp(exp).kind == Kind::kTensor) 704 return genTensorLoad(merger, codegen, rewriter, op, exp); 705 else if (merger.exp(exp).kind == Kind::kInvariant) 706 return genInvariantValue(merger, codegen, rewriter, exp); 707 Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); 708 Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); 709 switch (merger.exp(exp).kind) { 710 case Kind::kTensor: 711 case Kind::kInvariant: 712 llvm_unreachable("handled above"); 713 case Kind::kMulF: 714 return rewriter.create<MulFOp>(op.getLoc(), v0, v1); 715 case Kind::kMulI: 716 return rewriter.create<MulIOp>(op.getLoc(), v0, v1); 717 case Kind::kAddF: 718 return rewriter.create<AddFOp>(op.getLoc(), v0, v1); 719 case Kind::kAddI: 720 return rewriter.create<AddIOp>(op.getLoc(), v0, v1); 721 } 722 llvm_unreachable("unexpected expression kind"); 723 } 724 725 /// Hoists loop invariant tensor loads for which indices have been exhausted. 726 static void genInvariants(Merger &merger, CodeGen &codegen, 727 PatternRewriter &rewriter, linalg::GenericOp op, 728 unsigned exp, unsigned ldx, bool hoist) { 729 if (merger.exp(exp).kind == Kind::kTensor) { 730 // Inspect tensor indices. 731 bool atLevel = ldx == -1u; 732 OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0]; 733 auto map = op.getTiedIndexingMap(t); 734 auto enc = getSparseTensorEncoding(t->get().getType()); 735 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 736 unsigned idx = map.getDimPosition(perm(enc, d)); 737 if (!codegen.loops[idx]) 738 return; // still in play 739 else if (idx == ldx) 740 atLevel = true; 741 } 742 // All exhausted at this level (atLevel denotes exactly at this level). 743 OpOperand *lhs = op.getOutputOperand(0); 744 if (lhs == t) { 745 codegen.redExp = hoist ? exp : -1u; 746 } else if (atLevel) { 747 merger.exp(exp).val = 748 hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); 749 } 750 } else if (merger.exp(exp).kind != Kind::kInvariant) { 751 // Traverse into the binary operations. Note that we only hoist 752 // tensor loads, since subsequent MLIR/LLVM passes know how to 753 // deal with all other kinds of derived loop invariants. 
754 unsigned e0 = merger.exp(exp).e0; 755 unsigned e1 = merger.exp(exp).e1; 756 genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); 757 genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); 758 } 759 } 760 761 /// Generates initialization code for the subsequent loop sequence at 762 /// current index level. Returns true if the loop sequence needs to 763 /// maintain the universal index. 764 static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, 765 linalg::GenericOp op, std::vector<unsigned> &topSort, 766 unsigned at, llvm::BitVector &inits) { 767 bool needsUniv = false; 768 Location loc = op.getLoc(); 769 unsigned idx = topSort[at]; 770 771 // Initialize sparse positions. 772 for (unsigned b = 0, be = inits.size(); b < be; b++) { 773 if (inits[b]) { 774 unsigned tensor = merger.tensor(b); 775 assert(idx == merger.index(b)); 776 if (merger.isDim(b, Dim::kSparse)) { 777 // Initialize sparse index. 778 unsigned pat = at; 779 for (; pat != 0; pat--) { 780 if (codegen.pidxs[tensor][topSort[pat - 1]]) 781 break; 782 } 783 Value ptr = codegen.pointers[tensor][idx]; 784 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 785 Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 786 : codegen.pidxs[tensor][topSort[pat - 1]]; 787 codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0); 788 Value p1 = rewriter.create<AddIOp>(loc, p0, one); 789 codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1); 790 } else { 791 // Dense index still in play. 792 needsUniv = true; 793 } 794 } 795 } 796 797 // Initialize the universal dense index. 798 codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0); 799 return needsUniv; 800 } 801 802 /// Returns vectorization strategy. Any implicit inner loop in the Linalg 803 /// operation is a candidate. Whether it is actually converted to SIMD code 804 /// depends on the requested strategy. 805 static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) { 806 switch (codegen.options.vectorizationStrategy) { 807 case SparseVectorizationStrategy::kNone: 808 return false; 809 case SparseVectorizationStrategy::kDenseInnerLoop: 810 return isInner && !isSparse; 811 case SparseVectorizationStrategy::kAnyStorageInnerLoop: 812 return isInner; 813 } 814 llvm_unreachable("unexpected vectorization strategy"); 815 } 816 817 /// Returns parallelization strategy. Any implicit loop in the Linalg operation 818 /// that is marked "parallel" is a candidate. Whether it is actually converted 819 /// to a parallel operation depends on the requested strategy. 820 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, 821 bool isSparse, bool isVector) { 822 switch (codegen.options.parallelizationStrategy) { 823 case SparseParallelizationStrategy::kNone: 824 return false; 825 case SparseParallelizationStrategy::kDenseOuterLoop: 826 return isOuter && !isSparse && !isReduction && !isVector; 827 case SparseParallelizationStrategy::kAnyStorageOuterLoop: 828 return isOuter && !isReduction && !isVector; 829 case SparseParallelizationStrategy::kDenseAnyLoop: 830 return !isSparse && !isReduction && !isVector; 831 case SparseParallelizationStrategy::kAnyStorageAnyLoop: 832 return !isReduction && !isVector; 833 } 834 llvm_unreachable("unexpected parallelization strategy"); 835 } 836 837 /// Checks unit strides for dense tensors. 
The iteration graph may have ignored 838 /// dense access patterns in order to avoid cycles (sparse access patterns are 839 /// always placed innermost), but that means dense access has become strided. 840 /// For now, we reject vectorization of such cases. 841 /// TODO: implement strided load/stores on dense arrays 842 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, 843 unsigned idx) { 844 for (OpOperand *t : op.getInputAndOutputOperands()) { 845 if (!getSparseTensorEncoding(t->get().getType())) { 846 auto map = op.getTiedIndexingMap(t); 847 for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { 848 if (map.getDimPosition(d) == idx && d != rank - 1) 849 return false; 850 } 851 } 852 } 853 return true; 854 } 855 856 /// Generates a for-loop on a single index. 857 static Operation *genFor(Merger &merger, CodeGen &codegen, 858 PatternRewriter &rewriter, linalg::GenericOp op, 859 bool isOuter, bool isInner, unsigned idx, 860 llvm::BitVector &indices) { 861 unsigned fb = indices.find_first(); 862 unsigned tensor = merger.tensor(fb); 863 assert(idx == merger.index(fb)); 864 auto iteratorTypes = op.iterator_types().getValue(); 865 bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); 866 bool isSparse = merger.isDim(fb, Dim::kSparse); 867 bool isVector = isVectorFor(codegen, isInner, isSparse) && 868 denseUnitStrides(merger, op, idx); 869 bool isParallel = 870 isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); 871 872 // Prepare vector length. 873 if (isVector) 874 codegen.curVecLength = codegen.options.vectorLength; 875 876 // Loop bounds and increment. 877 Location loc = op.getLoc(); 878 Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; 879 Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; 880 Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength); 881 882 // Emit a parallel loop. 883 if (isParallel) { 884 assert(!isVector); 885 scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); 886 if (isSparse) 887 codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; 888 else 889 codegen.loops[idx] = parOp.getInductionVars()[0]; 890 rewriter.setInsertionPointToStart(parOp.getBody()); 891 return parOp; 892 } 893 894 // Emit a sequential loop, potentially with a scalarized reduction. 895 bool scalarRed = isInner && codegen.redExp != -1u; 896 SmallVector<Value, 4> operands; 897 if (scalarRed) { 898 Value load = genReductionStart(merger, codegen, rewriter, op); 899 operands.push_back(load); 900 } 901 scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands); 902 if (scalarRed) { 903 codegen.redVal = merger.exp(codegen.redExp).val = 904 forOp.getRegionIterArgs().front(); 905 } 906 // Assign induction variable to sparse or dense index. 907 Value iv = forOp.getInductionVar(); 908 if (isSparse) 909 codegen.pidxs[tensor][idx] = iv; 910 else 911 codegen.loops[idx] = iv; 912 rewriter.setInsertionPointToStart(forOp.getBody()); 913 // Share vector iteration mask between all subsequent loads/stores. 914 if (isVector) 915 codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step); 916 return forOp; 917 } 918 919 /// Emit a while-loop for co-iteration over multiple indices. 
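/// For instance, co-iterating two sparse vectors a and b over index i
/// corresponds roughly to the following C-like sketch (editorial
/// illustration; the actual emitted IR uses scf.while with the same
/// structure):
///
///   pa = a_pointers[0]; pb = b_pointers[0];
///   while (pa < a_pointers[1] && pb < b_pointers[1]) {
///     ia = a_indices[pa];
///     ib = b_indices[pb];
///     i = min(ia, ib);
///     if (ia == i && ib == i)  { /* use a and b */ }
///     else if (ia == i)        { /* use a only  */ }
///     else                     { /* use b only  */ }
///     pa += (ia == i) ? 1 : 0;
///     pb += (ib == i) ? 1 : 0;
///   }
///
/// The "before" region built here evaluates the conjunction of the bound
/// tests, and genWhileInduction emits the conditional increments.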
920 static Operation *genWhile(Merger &merger, CodeGen &codegen, 921 PatternRewriter &rewriter, linalg::GenericOp op, 922 unsigned idx, bool needsUniv, 923 llvm::BitVector &indices) { 924 SmallVector<Type, 4> types; 925 SmallVector<Value, 4> operands; 926 // Construct the while-loop with a parameter for each index. 927 Type indexType = rewriter.getIndexType(); 928 for (unsigned b = 0, be = indices.size(); b < be; b++) { 929 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 930 unsigned tensor = merger.tensor(b); 931 assert(idx == merger.index(b)); 932 types.push_back(indexType); 933 assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() && 934 "type mismatch for sparse index"); 935 operands.push_back(codegen.pidxs[tensor][idx]); 936 } 937 } 938 if (needsUniv) { 939 types.push_back(indexType); 940 assert(codegen.loops[idx].getType().isa<IndexType>() && 941 "type mismatch for universal index"); 942 operands.push_back(codegen.loops[idx]); 943 } 944 Location loc = op.getLoc(); 945 scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 946 Block *before = rewriter.createBlock(&whileOp.before(), {}, types); 947 Block *after = rewriter.createBlock(&whileOp.after(), {}, types); 948 949 // Build the "before" region, which effectively consists 950 // of a conjunction of "i < upper" tests on all induction. 951 rewriter.setInsertionPointToStart(&whileOp.before().front()); 952 Value cond; 953 unsigned o = 0; 954 for (unsigned b = 0, be = indices.size(); b < be; b++) { 955 if (indices[b] && merger.isDim(b, Dim::kSparse)) { 956 unsigned tensor = merger.tensor(b); 957 assert(idx == merger.index(b)); 958 Value op1 = before->getArgument(o); 959 Value op2 = codegen.highs[tensor][idx]; 960 Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2); 961 cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc; 962 codegen.pidxs[tensor][idx] = after->getArgument(o++); 963 } 964 } 965 if (needsUniv) 966 codegen.loops[idx] = after->getArgument(o++); 967 assert(o == operands.size()); 968 rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); 969 rewriter.setInsertionPointToStart(&whileOp.after().front()); 970 return whileOp; 971 } 972 973 /// Generates a for-loop or a while-loop, depending on whether it implements 974 /// singleton iteration or co-iteration over the given conjunction. 975 static Operation *genLoop(Merger &merger, CodeGen &codegen, 976 PatternRewriter &rewriter, linalg::GenericOp op, 977 std::vector<unsigned> &topSort, unsigned at, 978 bool needsUniv, llvm::BitVector &indices) { 979 unsigned idx = topSort[at]; 980 if (indices.count() == 1) { 981 bool isOuter = at == 0; 982 bool isInner = at == topSort.size() - 1; 983 return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, 984 indices); 985 } 986 genReductionEnd(merger, codegen, rewriter, op); // cannot chain 987 return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); 988 } 989 990 /// Generates the local variables for this loop, consisting of the sparse 991 /// indices, restored universal dense index, and dense positions. 992 static void genLocals(Merger &merger, CodeGen &codegen, 993 PatternRewriter &rewriter, linalg::GenericOp op, 994 std::vector<unsigned> &topSort, unsigned at, 995 bool needsUniv, llvm::BitVector &locals) { 996 Location loc = op.getLoc(); 997 unsigned idx = topSort[at]; 998 999 // Initialize sparse indices. 
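  // For each sparse tensor involved in this loop, load the index value at
  // the current position (idx = indices[pidx]) and, in the absence of a
  // universal index, track the minimum of these loaded indices; that
  // minimum becomes the value of loop index idx for this iteration.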
1000 Value min; 1001 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1002 if (locals[b] && merger.isDim(b, Dim::kSparse)) { 1003 unsigned tensor = merger.tensor(b); 1004 assert(idx == merger.index(b)); 1005 Value ptr = codegen.indices[tensor][idx]; 1006 Value s = codegen.pidxs[tensor][idx]; 1007 Value load = genLoad(codegen, rewriter, loc, ptr, s); 1008 codegen.idxs[tensor][idx] = load; 1009 if (!needsUniv) { 1010 if (min) { 1011 Value cmp = 1012 rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min); 1013 min = rewriter.create<SelectOp>(loc, cmp, load, min); 1014 } else { 1015 min = load; 1016 } 1017 } 1018 } 1019 } 1020 1021 // Merge dense universal index over minimum. 1022 if (min) { 1023 assert(!needsUniv); 1024 codegen.loops[idx] = min; 1025 } 1026 1027 // Initialize dense positions. Note that we generate dense indices of the 1028 // output tensor unconditionally, since they may not appear in the lattice, 1029 // but may be needed for linearized codegen. 1030 for (unsigned b = 0, be = locals.size(); b < be; b++) { 1031 if ((locals[b] || merger.isOutTensor(b, idx)) && 1032 merger.isDim(b, Dim::kDense)) { 1033 unsigned tensor = merger.tensor(b); 1034 assert(idx == merger.index(b)); 1035 unsigned pat = at; 1036 for (; pat != 0; pat--) 1037 if (codegen.pidxs[tensor][topSort[pat - 1]]) 1038 break; 1039 Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0) 1040 : codegen.pidxs[tensor][topSort[pat - 1]]; 1041 codegen.pidxs[tensor][idx] = genAddress( 1042 codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]); 1043 } 1044 } 1045 } 1046 1047 /// Generates the induction structure for a while-loop. 1048 static void genWhileInduction(Merger &merger, CodeGen &codegen, 1049 PatternRewriter &rewriter, linalg::GenericOp op, 1050 unsigned idx, bool needsUniv, 1051 llvm::BitVector &induction, ResultRange results) { 1052 Location loc = op.getLoc(); 1053 unsigned o = 0; 1054 SmallVector<Value, 4> operands; 1055 Value one = rewriter.create<ConstantIndexOp>(loc, 1); 1056 for (unsigned b = 0, be = induction.size(); b < be; b++) { 1057 if (induction[b] && merger.isDim(b, Dim::kSparse)) { 1058 unsigned tensor = merger.tensor(b); 1059 assert(idx == merger.index(b)); 1060 Value op1 = codegen.idxs[tensor][idx]; 1061 Value op2 = codegen.loops[idx]; 1062 Value op3 = codegen.pidxs[tensor][idx]; 1063 Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2); 1064 Value add = rewriter.create<AddIOp>(loc, op3, one); 1065 operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3)); 1066 codegen.pidxs[tensor][idx] = results[o++]; 1067 } 1068 } 1069 if (needsUniv) { 1070 operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one)); 1071 codegen.loops[idx] = results[o++]; 1072 } 1073 assert(o == operands.size()); 1074 rewriter.create<scf::YieldOp>(loc, operands); 1075 } 1076 1077 /// Generates a single if-statement within a while-loop. 
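/// The condition tests that every sparse tensor in the given lattice point
/// is currently at the loop index, e.g. when co-iterating a and b the
/// lattice point {a} yields "if (ia == i)" and the point {a,b} yields
/// "if (ia == i && ib == i)", with dense tensors contributing a trivially
/// true clause.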
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, llvm::BitVector &conditions) {
  Location loc = op.getLoc();
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
  return ifOp;
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    OpOperand *lhs = op.getOutputOperand(0);
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }
  assert(codegen.curVecLength == 1);

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  Location loc = op.getLoc();
  unsigned idx = topSort[at];
  unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true);
  bool needsUniv = false;
  if (genInit(merger, codegen, rewriter, op, topSort, at,
              merger.lat(l0).bits)) {
    // Maintain the universal index only if it is actually
    // consumed by a subsequent lattice point.
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse)) {
        needsUniv = true;
        break;
      }
    }
  }

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned i = 0; i < lsize; i++) {
    unsigned li = merger.set(lts)[i];

    // Emit loop.
    codegen.curVecLength = 1;
    llvm::BitVector indices = merger.lat(li).simple;
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
              merger.lat(li).bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if-statements for co-iteration.
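    // As an illustration, for a sparse vector addition x(i) = a(i) + b(i)
    // with dense x, the optimized lattice for index i contains the points
    // {a,b}, {a}, and {b}, and these two nested loops over lattice points
    // emit, schematically:
    //
    //   while (pa < ha && pb < hb) {     // loop for L0 = {a,b}
    //     i = min(ia, ib);
    //     if (ia == i && ib == i)  x[i] = a_values[pa] + b_values[pb];
    //     else if (ia == i)        x[i] = a_values[pa];
    //     else                     x[i] = b_values[pb];
    //     /* conditional increments of pa and pb */
    //   }
    //   for (; pa < ha; pa++)  x[a_indices[pa]] = a_values[pa];  // L1 = {a}
    //   for (; pb < hb; pb++)  x[b_indices[pb]] = b_values[pb];  // L2 = {b}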
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        merger.lat(li).bits, whileOp.results());
    } else {
      needsUniv = false;
      if (codegen.redVal) {
        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
        codegen.redVal = loop->getResult(0);
      }
    }
    rewriter.setInsertionPointAfter(loop);
  }

  // Wrap-up loop sequence.
  codegen.curVecLength = 1;
  genReductionEnd(merger, codegen, rewriter, op);
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
  codegen.loops[idx] = Value();
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  unsigned tensor = lhs->getOperandNumber();
  auto map = op.getTiedIndexingMap(lhs);
  auto enc = getSparseTensorEncoding(resType);
  Value result = codegen.buffers.back(); // value array
  if (enc) {
    // The sparse annotation unambiguously defines the arrays needed
    // to "reconstruct" the sparse tensor from the storage scheme
    // (even though lowering should never need this eventually).
    SmallVector<Value, 4> args;
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned idx = map.getDimPosition(perm(enc, d));
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        args.push_back(codegen.pointers[tensor][idx]);
        args.push_back(codegen.indices[tensor][idx]);
      }
    }
    args.push_back(result);
    result = rewriter.create<ToTensorOp>(loc, resType, args);
  } else {
    // To "reconstruct" a non-annotated tensor, simply load it
    // from the bufferized value.
    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
  }
  rewriter.replaceOp(op, result);
}

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Rejects an inadmissible tensor expression.
    if (!isAdmissibleTensorExp(merger, op, exp.getValue()))
      return failure();

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
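
//===----------------------------------------------------------------------===//
// Usage sketch (editorial, not part of this file): a pass would typically
// drive these rewrite rules with the greedy pattern driver, along the lines
// of the following (assuming a hypothetical pass class SparsificationPass
// and default-constructed SparsificationOptions):
//
//   void SparsificationPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateSparsificationPatterns(patterns, SparsificationOptions());
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }
//===----------------------------------------------------------------------===//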