//===- Sparsification.cpp - Implementation of sparsification --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements converting sparse tensor types to actual sparse code.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// Declarations of data structures.
//===----------------------------------------------------------------------===//

namespace {

// Iteration graph sorting.
enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

// Reduction kinds.
enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };

// Code generation.
struct CodeGen {
  CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
          OpOperand *op, unsigned nest)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)), redVal(), sparseOut(op),
        outerParNest(nest), lexIdx(), expValues(), expFilled(), expAdded(),
        expCount(), curVecMask() {}
  /// Sparsification options.
  SparsificationOptions options;
  /// Universal dense indices and upper bounds (by index). The loops array
  /// is updated with the value of the universal dense index in the current
  /// loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  /// Buffers for storing dense and sparse numerical values (by tensor).
  /// This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  /// This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
  /// Sparse iteration information (by tensor and index). These arrays
  /// are updated to remain current within the current loop.
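  /// (For a sparse dimension, `pidxs` holds the current position into the
  /// pointer/index arrays, `highs` the exclusive upper bound of that position
  /// range, and `idxs` the index value loaded at the current position.)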
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
  /// Current reduction, updated during code generation. When indices of a
  /// reduction are exhausted, all inner loops can use a scalarized reduction.
  unsigned redExp = -1u;
  Value redVal;
  Reduction redKind = kNoReduc;
  // Sparse tensor as output. Implemented either through direct injective
  // insertion in lexicographic index order (where indices are updated
  // in the temporary array `lexIdx`) or through access pattern expansion
  // in the innermost loop nest (`expValues` through `expCount`).
  OpOperand *sparseOut;
  unsigned outerParNest;
  Value lexIdx;
  Value expValues;
  Value expFilled;
  Value expAdded;
  Value expCount;
  // Current vector length and mask.
  unsigned curVecLength = 1;
  Value curVecMask;
};

} // namespace

//===----------------------------------------------------------------------===//
// Sparse compiler analysis methods.
//===----------------------------------------------------------------------===//

/// Helper method to apply dimension ordering permutation.
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    auto order = enc.getDimOrdering();
    if (order) {
      assert(order.isPermutation());
      return order.getDimPosition(d);
    }
  }
  return d;
}

/// Helper method to translate dim level type to internal representation.
static Dim toDim(const SparseTensorEncodingAttr &enc, unsigned d) {
  if (enc) {
    SparseTensorEncodingAttr::DimLevelType tp = enc.getDimLevelType()[d];
    if (tp == SparseTensorEncodingAttr::DimLevelType::Compressed)
      return Dim::kSparse;
    if (tp == SparseTensorEncodingAttr::DimLevelType::Singleton)
      return Dim::kSingle;
  }
  return Dim::kDense;
}

/// Helper method to inspect affine expressions. Rejects cases where the
/// same index is used more than once. Also rejects affine expressions
/// that are not a direct index for annotated tensors.
// TODO: accept more affine cases for sparse tensors
static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, Dim dim,
                       bool isDense) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (!merger.isDim(tensor, idx, Dim::kUndef))
      return false; // used more than once
    merger.setDim(tensor, idx, dim);
    return true;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    if (!isDense)
      return false;
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return findAffine(merger, tensor, binOp.getLHS(), dim, isDense) &&
           findAffine(merger, tensor, binOp.getRHS(), dim, isDense);
  }
  case AffineExprKind::Constant:
    return isDense;
  default:
    return false;
  }
}

/// Helper method to inspect sparse encodings in the tensor types.
/// Fills the per-dimension sparsity information for all tensors.
/// Returns true if the sparse annotations and affine subscript
/// expressions of all tensors are admissible. Returns false if
/// no annotations are found or inadmissible constructs occur.
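/// For example, for a matrix annotated (dense, compressed), i.e. CSR, that is
/// accessed as A(i,j), dimension i is recorded as dense and j as sparse.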
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
  bool annotated = false;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    if (enc)
      annotated = true;
    assert(map.getNumResults() == op.getRank(t));
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      unsigned tensor = t->getOperandNumber();
      AffineExpr a = map.getResult(perm(enc, d));
      if (!findAffine(merger, tensor, a, toDim(enc, d), !enc))
        return false; // inadmissible affine expression
    }
  }
  return annotated;
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Helper method to add all constraints from the indices in one affine
/// expression before all indices in the other affine expression. For
/// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
                               AffineExpr a, AffineExpr b, unsigned fidx) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (b)
      addAffineOrderings(adjM, b, AffineExpr(), idx);
    else
      adjM[fidx][idx] = true;
    break;
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    addAffineOrderings(adjM, binOp.getLHS(), b, fidx);
    addAffineOrderings(adjM, binOp.getRHS(), b, fidx);
    break;
  }
  default:
    break;
  }
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort,
                                  unsigned mask) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    assert(map.getNumDims() == n);
    // Skip dense tensor constraints when not requested.
    if (!(mask & SortMask::kIncludeDense) && !enc)
      continue;
    // Each tensor expression and optional dimension ordering (row-major
    // by default) puts an ordering constraint on the loop indices. For
    // example, the tensor expression A_ijk forces the ordering i < j < k
    // on the loop indices if no explicit dimension ordering is given.
    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr f = map.getResult(perm(enc, d - 1));
      AffineExpr t = map.getResult(perm(enc, d));
      addAffineOrderings(adjM, f, t, 0);
    }
    // Push unrelated loops into sparse iteration space, so these
    // will be skipped more often.
    if (mask & SortMask::kIncludeUndef) {
      unsigned tensor = t->getOperandNumber();
      for (unsigned i = 0; i < n; i++)
        if (merger.isDim(tensor, i, Dim::kSparse))
          for (unsigned j = 0; j < n; j++)
            if (merger.isDim(tensor, j, Dim::kUndef))
              adjM[i][j] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.clear();
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Returns true if tensor has an in-place annotation.
static bool isInPlace(Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    if (auto funcOp = dyn_cast<FuncOp>(arg.getOwner()->getParentOp()))
      if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
              arg.getArgNumber(),
              bufferization::BufferizableOpInterface::kInplaceableAttrName))
        return attr.getValue();
  return false;
}

/// Returns true if tensor materializes uninitialized into the computation.
static bool isMaterializing(Value val) {
  return val.getDefiningOp<linalg::InitTensorOp>() ||
         val.getDefiningOp<InitOp>();
}

/// Returns true when the tensor expression is admissible for codegen.
/// Since all sparse input tensors are admissible, we just need to check
/// whether the out tensor in the tensor expression codegen is admissible.
/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
/// nesting depth when a "truly dynamic" sparse tensor output occurs.
static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
  OpOperand *lhs = op.getOutputOperand(0);
  unsigned tensor = lhs->getOperandNumber();
  auto enc = getSparseTensorEncoding(lhs->get().getType());
  // A non-annotated output tensor is assumed dense, and becomes a random
  // access n-dim memref. Admissible since insertions cannot occur.
  if (!enc)
    return true;
  // An all-dense annotated "sparse" output tensor becomes a linearized random
  // access 1-dim memref. Also admissible since insertions cannot occur.
  bool allDense = true;
  auto iteratorTypes = op.iterator_types().getValue();
  unsigned numLoops = iteratorTypes.size();
  for (unsigned i = 0; i < numLoops; i++)
    if (merger.isDim(tensor, i, Dim::kSparse)) {
      allDense = false;
      break;
    }
  if (allDense)
    return true;
  // A tensor expression with a sparse output tensor that changes its values
  // but not its nonzero structure, an operation called "simply dynamic" in
  // [Bik96,Ch9], is also admissible without special codegen, provided
  // the tensor's underlying sparse storage scheme can be modified in place.
  if (merger.isSingleCondition(tensor, exp) && isInPlace(lhs->get()))
    return true;
  // Accept "truly dynamic" if the output tensor materializes uninitialized
  // into the computation and insertions occur in lexicographic index order.
  if (isMaterializing(lhs->get())) {
    unsigned nest = 0;
    for (unsigned i = 0; i < numLoops; i++) {
      if (isReductionIterator(iteratorTypes[topSort[i]]))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    if (nest >= op.getRank(lhs) - 1) {
      *sparseOut = lhs;
      outerParNest = nest;
      return true;
    }
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (reductions).
//===----------------------------------------------------------------------===//

/// Maps reduction kind to vector::CombiningKind.
static vector::CombiningKind getCombiningKind(Reduction kind) {
  switch (kind) {
  case kNoReduc:
    break;
  case kSum:
    return vector::CombiningKind::ADD;
  case kProduct:
    return vector::CombiningKind::MUL;
  case kAnd:
    return vector::CombiningKind::AND;
  case kOr:
    return vector::CombiningKind::OR;
  case kXor:
    return vector::CombiningKind::XOR;
  }
  llvm_unreachable("unknown reduction kind");
}

/// Maps operation to reduction.
static Reduction getReduction(Kind kind) {
  switch (kind) {
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
    return kSum;
  case Kind::kMulF:
  case Kind::kMulI:
    return kProduct;
  case Kind::kAndI:
    return kAnd;
  case Kind::kOrI:
    return kOr;
  case Kind::kXorI:
    return kXor;
  default:
    llvm_unreachable("unexpected reduction operator");
  }
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
static Value genVectorReducInit(CodeGen &codegen, PatternRewriter &rewriter,
                                Location loc, VectorType vtp) {
  Value r = codegen.redVal;
  switch (codegen.redKind) {
  case kNoReduc:
    break;
  case kSum:
  case kXor:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case kProduct:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case kAnd:
  case kOr:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
}

/// Generates final value for a vector reduction.
static Value genVectorReducEnd(CodeGen &codegen, PatternRewriter &rewriter,
                               Location loc, VectorType vtp) {
  vector::CombiningKind kind = getCombiningKind(codegen.redKind);
  return rewriter.create<vector::ReductionOp>(loc, kind, codegen.redVal);
}

/// Updates scalarized reduction value.
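/// The value is stored both in the codegen state and in the merger's
/// expression, so that subsequent tensor loads of that expression pick up
/// the chained SSA value.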
static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
  assert(codegen.redKind != kNoReduc);
  codegen.redVal = merger.exp(codegen.redExp).val = reduc;
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===//

/// Generates buffer for the output tensor. Note that all sparse kernels
/// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
/// the output buffer is already initialized to all zeroes and only nonzero
/// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
/// only nonzero values are used for the updates and no assumption on the
/// original contents of the output buffer is necessary.
static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                             linalg::GenericOp op, MemRefType denseTp,
                             ArrayRef<Value> args) {
  Location loc = op.getLoc();
  Value tensor = op.getOutputOperand(0)->get();
  // The output tensor could simply materialize from the buffer that will
  // be generated for the tensor present in the outs() clause. This has
  // the major advantage that the sparse kernel only updates the nonzero
  // positions for the output tensor.
  if (isInPlace(tensor))
    return rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
  // By default, a new buffer is allocated which is initialized to the
  // tensor defined in the outs() clause. This is always correct but
  // introduces a dense initialization component that may negatively
  // impact the running complexity of the sparse kernel. If the tensor
  // materializes into the computation, we need to preserve the zero
  // initialization assumption of all sparse output buffers.
  Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
  if (isMaterializing(tensor)) {
    Value zero = constantZero(rewriter, loc, denseTp.getElementType());
    rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{alloc});
  } else {
    Value init =
        rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
    rewriter.create<memref::CopyOp>(loc, init, alloc);
  }
  return alloc;
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and obtain dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    unsigned tensor = t->getOperandNumber();
    auto shape = op.getShape(t);
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    // Scan all dimensions of current tensor.
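    // The args vector collects the dynamic dimension sizes that are needed
    // below to allocate a dense output buffer of the proper shape.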
    args.clear();
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (a.getKind() != AffineExprKind::DimId)
        continue; // compound
      unsigned idx = a.cast<AffineDimExpr>().getPosition();
      // Handle sparse storage schemes.
      if (merger.isDim(tensor, idx, Dim::kSparse)) {
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp =
            MemRefType::get(dynShape, getPointerOverheadType(rewriter, enc));
        auto indTp =
            MemRefType::get(dynShape, getIndexOverheadType(rewriter, enc));
        Value dim = constantIndex(rewriter, loc, d);
        // Generate sparse primitives to obtain pointer and indices.
        codegen.pointers[tensor][idx] =
            rewriter.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
        codegen.indices[tensor][idx] =
            rewriter.create<ToIndicesOp>(loc, indTp, t->get(), dim);
      }
      // Find upper bound in current dimension.
      unsigned p = perm(enc, d);
      Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
      if (ShapedType::isDynamic(shape[p]))
        args.push_back(up);
      assert(codegen.highs[tensor][idx] == nullptr);
      codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
    }
    // Perform the required bufferization. Dense inputs materialize
    // from the input tensors. Dense outputs need special handling.
    // Sparse inputs use sparse primitives to obtain the values.
    // We also accept in-place all-dense annotated "sparse" outputs.
    Type elementType = getElementTypeOrSelf(t->get().getType());
    if (!enc) {
      // Non-annotated dense tensors.
      auto denseTp = MemRefType::get(shape, elementType);
      if (tensor < op.getNumInputs())
        codegen.buffers[tensor] =
            rewriter.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
      else
        codegen.buffers[tensor] =
            genOutputBuffer(codegen, rewriter, op, denseTp, args);
    } else if (t == codegen.sparseOut) {
      // True sparse output needs a lexIdx array.
      Value rank = constantIndex(rewriter, loc, op.getRank(t));
      auto dynShape = {ShapedType::kDynamicSize};
      auto memTp = MemRefType::get(dynShape, rewriter.getIndexType());
      codegen.lexIdx = rewriter.create<memref::AllocaOp>(loc, memTp, rank);
    } else {
      // Annotated sparse tensors.
      auto dynShape = {ShapedType::kDynamicSize};
      auto sparseTp = MemRefType::get(dynShape, elementType);
      codegen.buffers[tensor] =
          rewriter.create<ToValuesOp>(loc, sparseTp, t->get());
    }
  }
}

/// Constructs vector type.
static VectorType vectorType(CodeGen &codegen, Type etp) {
  return VectorType::get(codegen.curVecLength, etp);
}

/// Constructs vector type from pointer.
static VectorType vectorType(CodeGen &codegen, Value ptr) {
  return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
                           Value iv, Value lo, Value hi, Value step) {
  Location loc = iv.getLoc();
  VectorType mtp = vectorType(codegen, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
      return rewriter.create<vector::BroadcastOp>(
          loc, mtp, constantI1(rewriter, loc, true));
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto minMap = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end =
      rewriter.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
                           Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  VectorType vtp = vectorType(codegen, ptr);
  Value pass = constantZero(rewriter, loc, vtp);
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(
        loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
                                               codegen.curVecMask, pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
                           Value rhs, Value ptr, ArrayRef<Value> args) {
  Location loc = ptr.getLoc();
  if (args.back().getType().isa<VectorType>()) {
    SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
    Value indexVec = args.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
                                       codegen.curVecMask, rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
                                         rhs);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(CodeGen &codegen,
                                     PatternRewriter &rewriter, Value val) {
  VectorType vtp = vectorType(codegen, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates an affine expression.
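/// For example, a dense subscript i + 2 * j is emitted as an arith.muli of
/// the constant 2 with the loop index for j, followed by an arith.addi with
/// the loop index for i.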
//
// TODO: generalize for sparse tensor subscripts
//
static Value genAffine(CodeGen &codegen, PatternRewriter &rewriter,
                       AffineExpr a, Location loc) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    return codegen.loops[idx]; // universal dense index
  }
  case AffineExprKind::Add: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::AddIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return rewriter.create<arith::MulIOp>(
        loc, genAffine(codegen, rewriter, binOp.getLHS(), loc),
        genAffine(codegen, rewriter, binOp.getRHS(), loc));
  }
  case AffineExprKind::Constant: {
    int64_t c = a.cast<AffineConstantExpr>().getValue();
    return constantIndex(rewriter, loc, c);
  }
  default:
    llvm_unreachable("unexpected affine subscript");
  }
}

/// Generates index for load/store on sparse tensor.
static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  AffineExpr a = map.getResult(perm(enc, map.getNumResults() - 1));
  assert(a.getKind() == AffineExprKind::DimId);
  unsigned idx = a.cast<AffineDimExpr>().getPosition();
  return codegen.loops[idx];
}

/// Generates subscript for load/store on a dense or sparse tensor.
static Value genSubscript(CodeGen &codegen, PatternRewriter &rewriter,
                          linalg::GenericOp op, OpOperand *t,
                          SmallVector<Value, 4> &args) {
  unsigned tensor = t->getOperandNumber();
  auto map = op.getTiedIndexingMap(t);
  auto enc = getSparseTensorEncoding(t->get().getType());
  unsigned rank = map.getNumResults();
  if (enc) {
    // Note that currently, all sparse subscripts are simple.
    // TODO: accept affine too?
    AffineExpr a = map.getResult(perm(enc, rank - 1));
    assert(a.getKind() == AffineExprKind::DimId);
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    assert(codegen.pidxs[tensor][idx] != nullptr);
    args.push_back(codegen.pidxs[tensor][idx]); // position index
  } else {
    for (unsigned d = 0; d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      args.push_back(genAffine(codegen, rewriter, a, op.getLoc()));
    }
  }
  return codegen.buffers[tensor];
}

/// Generates insertion code to implement dynamic tensor load.
static Value genInsertionLoad(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, OpOperand *t) {
  Location loc = op.getLoc();
  // Direct lexicographic index order, tensor loads as zero.
  if (!codegen.expValues) {
    Type tp = getElementTypeOrSelf(t->get().getType());
    return constantZero(rewriter, loc, tp);
  }
  // Load from expanded access pattern.
  Value index = genIndex(codegen, op, t);
  return rewriter.create<memref::LoadOp>(loc, codegen.expValues, index);
}

/// Generates insertion code to implement dynamic tensor store.
static void genInsertionStore(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, OpOperand *t, Value rhs) {
  Location loc = op.getLoc();
  // Direct insertion in lexicographic index order.
  if (!codegen.expValues) {
    rewriter.create<LexInsertOp>(loc, t->get(), codegen.lexIdx, rhs);
    return;
  }
  // Generates insertion code along expanded access pattern.
  //   if (!expFilled[i]) then
  //     expFilled[i] = true
  //     expAdded[inserts++] = i
  //   endif
  //   values[i] = rhs
  Value index = genIndex(codegen, op, t);
  Value fval = constantI1(rewriter, loc, false);
  Value tval = constantI1(rewriter, loc, true);
  // If statement.
  Value filled = rewriter.create<memref::LoadOp>(loc, codegen.expFilled, index);
  Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              filled, fval);
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, rewriter.getIndexType(),
                                              cond, /*else=*/true);
  // True branch.
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  rewriter.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
  rewriter.create<memref::StoreOp>(loc, index, codegen.expAdded,
                                   codegen.expCount);
  Value one = constantIndex(rewriter, loc, 1);
  Value add = rewriter.create<arith::AddIOp>(loc, codegen.expCount, one);
  rewriter.create<scf::YieldOp>(loc, add);
  // False branch.
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
  rewriter.create<scf::YieldOp>(loc, codegen.expCount);
  rewriter.setInsertionPointAfter(ifOp);
  // Value assignment.
  codegen.expCount = ifOp.getResult(0);
  rewriter.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
      return genVectorInvariantValue(codegen, rewriter, val);
    return val;
  }
  // Load during insertion.
  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
  if (t == codegen.sparseOut)
    return genInsertionLoad(codegen, rewriter, op, t);
  // Actual load.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    return genVectorLoad(codegen, rewriter, ptr, args);
  return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
}

/// Generates a store on a dense or sparse tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           Value rhs) {
  Location loc = op.getLoc();
  // Test if this is a scalarized reduction.
  if (codegen.redVal) {
    if (codegen.curVecLength > 1)
      rhs = rewriter.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
                                             codegen.redVal);
    updateReduc(merger, codegen, rhs);
    return;
  }
  // Store during insertion.
  OpOperand *t = op.getOutputOperand(0);
  if (t == codegen.sparseOut) {
    genInsertionStore(codegen, rewriter, op, t, rhs);
    return;
  }
  // Actual store.
  SmallVector<Value, 4> args;
  Value ptr = genSubscript(codegen, rewriter, op, t, args);
  if (codegen.curVecLength > 1)
    genVectorStore(codegen, rewriter, rhs, ptr, args);
  else
    rewriter.create<memref::StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
                     Value ptr, Value s) {
  // See https://llvm.org/docs/GetElementPtr.html for some background on
  // the complications described below.
  if (codegen.curVecLength > 1) {
    // Since the index vector is used in subsequent gather/scatter operations,
    // which effectively defines an unsigned pointer + signed index, we must
    // zero extend the vector to an index width. For 8-bit and 16-bit values,
    // a 32-bit index width suffices. For 32-bit values, zero extending the
    // elements into 64-bit loses some performance since the 32-bit indexed
    // gather/scatter is more efficient than the 64-bit index variant (if the
    // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
    // preserve this performance). For 64-bit values, there is no good way
    // to state that the indices are unsigned, which creates the potential of
    // incorrect address calculations in the unlikely case we need such
    // extremely large offsets.
    Type etp = ptr.getType().cast<MemRefType>().getElementType();
    Value vload = genVectorLoad(codegen, rewriter, ptr, {s});
    if (!etp.isa<IndexType>()) {
      if (etp.getIntOrFloatBitWidth() < 32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI32Type()), vload);
      else if (etp.getIntOrFloatBitWidth() < 64 &&
               !codegen.options.enableSIMDIndex32)
        vload = rewriter.create<arith::ExtUIOp>(
            loc, vectorType(codegen, rewriter.getI64Type()), vload);
    }
    return vload;
  }
  // For the scalar case, we simply zero extend narrower indices into 64-bit
  // values before casting to index without a performance penalty. Here too,
  // however, indices that already are 64-bit, in theory, cannot express the
  // full range as explained above.
  Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
  if (!load.getType().isa<IndexType>()) {
    if (load.getType().getIntOrFloatBitWidth() < 64)
      load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
    load =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
  }
  return load;
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  Value val = merger.exp(exp).val;
  if (codegen.curVecLength > 1)
    return genVectorInvariantValue(codegen, rewriter, val);
  return val;
}

/// Generates an address computation "sz * p + i".
static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
                        Location loc, Value size, Value p, Value i) {
  Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
  if (auto vtp = i.getType().dyn_cast<VectorType>()) {
    Value inv =
        rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
    mul = genVectorInvariantValue(codegen, rewriter, inv);
  }
  return rewriter.create<arith::AddIOp>(loc, mul, i);
}

/// Generates an index value.
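/// (such values typically originate from a linalg.index operation in the
/// kernel region).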
static Value genIndexValue(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, unsigned exp,
                           unsigned ldx) {
  unsigned idx = merger.exp(exp).index;
  Value ival = codegen.loops[idx];
  Type itype = ival.getType();
  // During vectorization, we either encounter:
  // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or
  // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i.
  unsigned vl = codegen.curVecLength;
  if (vl > 1 && !itype.isa<VectorType>()) {
    Location loc = ival.getLoc();
    VectorType vtp = vectorType(codegen, itype);
    ival = rewriter.create<vector::BroadcastOp>(loc, vtp, ival);
    if (idx == ldx) {
      SmallVector<APInt, 4> integers;
      for (unsigned i = 0; i < vl; i++)
        integers.push_back(APInt(/*width=*/64, i));
      auto values = DenseElementsAttr::get(vtp, integers);
      Value incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
      ival = rewriter.create<arith::AddIOp>(loc, ival, incr);
    }
  }
  return ival;
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp, unsigned ldx) {
  Location loc = op.getLoc();
  if (exp == -1u)
    return Value();
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  if (merger.exp(exp).kind == Kind::kIndex)
    return genIndexValue(merger, codegen, rewriter, exp, ldx);
  Value v0 =
      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
  Value v1 =
      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
  return merger.buildExp(rewriter, loc, exp, v0, v1);
}

/// Determines if affine expression is invariant.
static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a,
                              unsigned ldx, bool &atLevel) {
  switch (a.getKind()) {
  case AffineExprKind::DimId: {
    unsigned idx = a.cast<AffineDimExpr>().getPosition();
    if (idx == ldx)
      atLevel = true;
    return codegen.loops[idx] != nullptr; // no longer in play?
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    auto binOp = a.cast<AffineBinaryOpExpr>();
    return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) &&
           isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel);
  }
  default:
    return true;
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp, unsigned ldx, bool atStart,
                          Kind last = Kind::kTensor) {
  if (exp == -1u)
    return;
  if (merger.exp(exp).kind == Kind::kTensor) {
    // Inspect tensor indices.
    bool atLevel = ldx == -1u;
    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
    auto map = op.getTiedIndexingMap(t);
    auto enc = getSparseTensorEncoding(t->get().getType());
    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
      AffineExpr a = map.getResult(perm(enc, d));
      if (!isInvariantAffine(codegen, a, ldx, atLevel))
        return; // still in play
    }
    // All exhausted at this level (atLevel denotes exactly at this level).
    if (!atLevel)
      return;
    OpOperand *lhs = op.getOutputOperand(0);
    if (lhs == t) {
      // Start or end a scalarized reduction.
      if (atStart) {
        Value load = genTensorLoad(merger, codegen, rewriter, op, exp);
        codegen.redKind = getReduction(last);
        codegen.redExp = exp;
        updateReduc(merger, codegen, load);
      } else {
        Value redVal = codegen.redVal;
        updateReduc(merger, codegen, Value());
        codegen.redExp = -1u;
        codegen.redKind = kNoReduc;
        genTensorStore(merger, codegen, rewriter, op, redVal);
      }
    } else {
      // Start or end loop invariant hoisting of a tensor load.
      merger.exp(exp).val =
          atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
    }
  } else if (merger.exp(exp).kind != Kind::kInvariant &&
             merger.exp(exp).kind != Kind::kIndex) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    Kind last = merger.exp(exp).kind;
    unsigned e0 = merger.exp(exp).children.e0;
    unsigned e1 = merger.exp(exp).children.e1;
    genInvariants(merger, codegen, rewriter, op, e0, ldx, atStart, last);
    genInvariants(merger, codegen, rewriter, op, e1, ldx, atStart, last);
  }
}

/// Generates an expanded access pattern in innermost dimension.
static void genExpansion(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         unsigned at, bool atStart) {
  OpOperand *lhs = codegen.sparseOut;
  if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 ||
      at != codegen.outerParNest)
    return; // not needed at this level
  // Generate start or end of an expanded access pattern.
  Value tensor = lhs->get();
  Location loc = op.getLoc();
  if (atStart) {
    auto dynShape = {ShapedType::kDynamicSize};
    Type etp = tensor.getType().cast<ShapedType>().getElementType();
    Type t1 = MemRefType::get(dynShape, etp);
    Type t2 = MemRefType::get(dynShape, rewriter.getI1Type());
    Type t3 = MemRefType::get(dynShape, rewriter.getIndexType());
    Type t4 = rewriter.getIndexType();
    auto res =
        rewriter.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor);
    assert(res.getNumResults() == 4);
    assert(!codegen.expValues);
    codegen.expValues = res.getResult(0);
    codegen.expFilled = res.getResult(1);
    codegen.expAdded = res.getResult(2);
    codegen.expCount = res.getResult(3);
  } else {
    assert(codegen.expValues);
    rewriter.create<CompressOp>(loc, tensor, codegen.lexIdx, codegen.expValues,
                                codegen.expFilled, codegen.expAdded,
                                codegen.expCount);
    codegen.expValues = codegen.expFilled = codegen.expAdded =
        codegen.expCount = Value();
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
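  // For a compressed dimension, the position range is obtained by loading
  //   pidxs[tensor][idx] = pointers[p]      (lower bound)
  //   highs[tensor][idx] = pointers[p + 1]  (upper bound)
  // where p is the position inherited from the enclosing dimension (or 0 at
  // the outermost level).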
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isDim(b, Dim::kSparse)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = constantIndex(rewriter, loc, 1);
        Value p0 = (pat == 0) ? constantIndex(rewriter, loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<arith::AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = constantIndex(rewriter, loc, 0);
  return needsUniv;
}

/// Returns vectorization strategy. Any implicit inner loop in the Linalg
/// operation is a candidate. Whether it is actually converted to SIMD code
/// depends on the requested strategy.
static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
                        bool isSparse) {
  // Reject vectorization of sparse output, unless innermost is reduction.
  if (codegen.sparseOut && !isReduction)
    return false;
  // Inspect strategy.
  switch (codegen.options.vectorizationStrategy) {
  case SparseVectorizationStrategy::kNone:
    return false;
  case SparseVectorizationStrategy::kDenseInnerLoop:
    return isInner && !isSparse;
  case SparseVectorizationStrategy::kAnyStorageInnerLoop:
    return isInner;
  }
  llvm_unreachable("unexpected vectorization strategy");
}

/// Returns parallelization strategy. Any implicit loop in the Linalg operation
/// that is marked "parallel" is a candidate. Whether it is actually converted
/// to a parallel operation depends on the requested strategy.
static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                          bool isSparse, bool isVector) {
  // Reject parallelization of sparse output.
  if (codegen.sparseOut)
    return false;
  // Inspect strategy.
  switch (codegen.options.parallelizationStrategy) {
  case SparseParallelizationStrategy::kNone:
    return false;
  case SparseParallelizationStrategy::kDenseOuterLoop:
    return isOuter && !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageOuterLoop:
    return isOuter && !isReduction && !isVector;
  case SparseParallelizationStrategy::kDenseAnyLoop:
    return !isSparse && !isReduction && !isVector;
  case SparseParallelizationStrategy::kAnyStorageAnyLoop:
    return !isReduction && !isVector;
  }
  llvm_unreachable("unexpected parallelization strategy");
}

/// Checks unit stride for dense tensors. The iteration graph may have ignored
/// dense access patterns in order to avoid cycles (sparse access patterns are
/// always placed innermost), but that means dense access has become strided.
/// This prevents effective vectorization.
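/// For example, when the innermost loop index subscripts an outer dimension
/// of a dense tensor, as in a[j][i] with innermost j, the access is strided
/// by the full row length.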
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                             unsigned idx) {
  for (OpOperand *t : op.getInputAndOutputOperands()) {
    if (!getSparseTensorEncoding(t->get().getType())) {
      auto map = op.getTiedIndexingMap(t);
      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
        AffineExpr a = map.getResult(d);
        // Report non-unit stride if innermost index appears at an outer
        // dimension (true non-unit stride) or if the innermost index appears
        // in a compound subscript in the innermost dimension. Even if the
        // latter is unit stride, it does not play well with scatter/gather.
        // TODO: accept unit stride affine innermost like a[i,j+k+1]?
        if (a.isFunctionOfDim(idx) &&
            ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId)))
          return false;
      }
    }
  }
  return true;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));
  auto iteratorTypes = op.iterator_types().getValue();
  bool isReduction = isReductionIterator(iteratorTypes[idx]);
  bool isSparse = merger.isDim(fb, Dim::kSparse);
  bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                  denseUnitStrides(merger, op, idx);
  bool isParallel =
      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

  // Prepare vector length.
  if (isVector)
    codegen.curVecLength = codegen.options.vectorLength;

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
  Value step = constantIndex(rewriter, loc, codegen.curVecLength);

  // Emit a parallel loop.
  if (isParallel) {
    assert(!isVector);
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential or vector loop.
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    // In a vector loop, bring reduction into SIMD form, if not already.
    if (isVector && !codegen.redVal.getType().isa<VectorType>()) {
      VectorType vtp = vectorType(codegen, codegen.redVal.getType());
      Value vred = genVectorReducInit(codegen, rewriter, loc, vtp);
      updateReduc(merger, codegen, vred);
    }
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues)
    operands.push_back(codegen.expCount);
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
  if (codegen.redVal)
    updateReduc(merger, codegen, forOp.getRegionIterArgs().front());
  if (codegen.expValues)
    codegen.expCount = forOp.getRegionIterArgs().back();
  // Assign induction variable to sparse or dense index.
  Value iv = forOp.getInductionVar();
  if (isSparse)
    codegen.pidxs[tensor][idx] = iv;
  else
    codegen.loops[idx] = iv;
  rewriter.setInsertionPointToStart(forOp.getBody());
  // Share vector iteration mask between all subsequent loads/stores.
  if (isVector)
    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (codegen.redVal) {
    types.push_back(codegen.redVal.getType());
    operands.push_back(codegen.redVal);
  }
  if (codegen.expValues) {
    types.push_back(indexType);
    operands.push_back(codegen.expCount);
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  assert(types.size() == operands.size());
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);

  SmallVector<Location> locs(types.size(), loc);
  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, locs);
  Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, types, locs);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction variables.
  rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                 op1, op2);
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (codegen.redVal)
    updateReduc(merger, codegen, after->getArgument(o++));
  if (codegen.expValues)
    codegen.expCount = after->getArgument(o++);
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
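/// A single index in the conjunction yields an ordinary for-loop, whereas
/// co-iteration over several sparse indices yields a while-loop in which each
/// tensor position is advanced only when its index matches the index
/// currently being processed.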
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(codegen, rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, load, min);
          min = rewriter.create<arith::SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions. Note that we generate dense indices of the
  // output tensor unconditionally, since they may not appear in the lattice,
  // but may be needed for linearized codegen.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if ((locals[b] || merger.isOutTensor(b, idx)) &&
        merger.isDim(b, Dim::kDense)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? constantIndex(rewriter, loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      codegen.pidxs[tensor][idx] = genAddress(
          codegen, rewriter, loc, codegen.sizes[idx], p, codegen.loops[idx]);
    }
  }

  // Move the insertion indices in lexicographic index order. During access
  // pattern expansion, we can skip setting the innermost dimension.
  if (codegen.sparseOut && !codegen.expValues) {
    Value pos = constantIndex(rewriter, loc, at);
    rewriter.create<memref::StoreOp>(loc, codegen.loops[idx], codegen.lexIdx,
                                     pos);
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              BitVector &induction,
                              scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
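  // Each enclosing if-statement must yield the current reduction value and/or
  // expansion count so that these values are properly threaded out of its
  // branches.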

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              BitVector &induction, scf::WhileOp whileOp) {
  Location loc = op.getLoc();
  // Finalize each else branch of all if statements.
  if (codegen.redVal || codegen.expValues) {
    while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
               rewriter.getInsertionBlock()->getParentOp())) {
      unsigned y = 0;
      SmallVector<Value, 4> yields;
      if (codegen.redVal) {
        yields.push_back(codegen.redVal);
        updateReduc(merger, codegen, ifOp.getResult(y++));
      }
      if (codegen.expValues) {
        yields.push_back(codegen.expCount);
        codegen.expCount = ifOp->getResult(y++);
      }
      assert(y == yields.size());
      rewriter.create<scf::YieldOp>(loc, yields);
      rewriter.setInsertionPointAfter(ifOp);
    }
  }
  rewriter.setInsertionPointToEnd(&whileOp.getAfter().front());
  // Finalize the induction. Note that the induction could be performed
  // in the individual if-branches to avoid re-evaluating the conditions.
  // However, that would result in a rather elaborate forest of yield
  // instructions during code generation. Moreover, performing the induction
  // after the if-statements more closely resembles code generated by TACO.
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = constantIndex(rewriter, loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++) {
    if (induction[b] && merger.isDim(b, Dim::kSparse)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 op1, op2);
      Value add = rewriter.create<arith::AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<arith::SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = whileOp->getResult(o++);
    }
  }
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, whileOp->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = whileOp->getResult(o++);
  }
  if (needsUniv) {
    operands.push_back(
        rewriter.create<arith::AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = whileOp->getResult(o++);
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(whileOp);
}
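// As an illustration (invented names), for a single sparse tensor A
// co-iterated with the universal index, the induction above roughly reads:
//
//   %hit = arith.cmpi eq, %idxA, %i : index
//   %inc = arith.addi %pidxA, %c1 : index
//   %new = arith.select %hit, %inc, %pidxA : index
//
// that is, a position only advances past entries whose index matched the
// current loop index, while the universal index always advances by one.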

/// Generates the induction structure for a for-loop.
static void genForInduction(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            Operation *loop) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, loop->getResult(o++));
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = loop->getResult(o++);
  }
  assert(o == operands.size());
  if (o > 0)
    rewriter.create<scf::YieldOp>(loc, operands);
  rewriter.setInsertionPointAfter(loop);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned idx, BitVector &conditions) {
  Location loc = op.getLoc();
  SmallVector<Type, 4> types;
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isDim(b, Dim::kSparse)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                op1, op2);
      } else {
        clause = constantI1(rewriter, loc, true);
      }
      cond = cond ? rewriter.create<arith::AndIOp>(loc, cond, clause) : clause;
    }
  }
  if (codegen.redVal)
    types.push_back(codegen.redVal.getType());
  if (codegen.expValues)
    types.push_back(rewriter.getIndexType());
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, types, cond, /*else=*/true);
  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
  return ifOp;
}

/// Generates the end of the true branch of an if-statement within a
/// while-loop.
static void endIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, scf::IfOp ifOp, Operation *loop,
                  Value redInput, Value cntInput) {
  SmallVector<Value, 4> operands;
  if (codegen.redVal) {
    operands.push_back(codegen.redVal);
    updateReduc(merger, codegen, redInput);
  }
  if (codegen.expValues) {
    operands.push_back(codegen.expCount);
    codegen.expCount = cntInput;
  }
  if (!operands.empty())
    rewriter.create<scf::YieldOp>(op.getLoc(), operands);
  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
}

//===----------------------------------------------------------------------===//
// Sparse compiler synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//

/// Starts a loop sequence at given level. Returns true if
/// the universal loop index must be maintained at this level.
static bool startLoopSeq(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         std::vector<unsigned> &topSort, unsigned exp,
                         unsigned at, unsigned idx, unsigned ldx,
                         unsigned lts) {
  assert(codegen.curVecLength == 1);
  assert(!codegen.loops[idx]);
  // Emit invariants at this loop sequence level.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/true);
  // Emit access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/true);
  // Emit further initialization at this loop sequence level.
  unsigned l0 = merger.set(lts)[0];
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, merger.lat(l0).bits);
  // Maintain the universal index only if it is actually
  // consumed by a subsequent lattice point.
  if (needsUniv) {
    unsigned lsize = merger.set(lts).size();
    for (unsigned i = 1; i < lsize; i++) {
      unsigned li = merger.set(lts)[i];
      if (!merger.hasAnyDimOf(merger.lat(li).simple, Dim::kSparse))
        return true;
    }
  }
  return false;
}
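// For example (a descriptive sketch of the resulting loop sequence), for a
// union such as x(i) = a(i) + b(i) with both "a" and "b" sparse, the sequence
// at index i consists of a while-loop that co-iterates a and b, followed by
// two for-loops that handle the remaining entries of a and of b,
// respectively. The universal index is only maintained when some lattice
// point below the top one has no sparse dimension left, e.g. when a dense
// operand participates in the union.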

/// Starts a single loop in current sequence.
static Operation *startLoop(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
                            std::vector<unsigned> &topSort, unsigned at,
                            unsigned li, bool needsUniv) {
  assert(codegen.curVecLength == 1);
  // Emit the for/while-loop control.
  Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at,
                            needsUniv, merger.lat(li).simple);
  // Emit the locals for this loop.
  genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv,
            merger.lat(li).bits);
  return loop;
}

/// Ends a single loop in current sequence. Returns the new value for
/// needsUniv.
static bool endLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, Operation *loop, unsigned idx,
                    unsigned li, bool needsUniv) {
  codegen.curVecLength = 1;
  // End a while-loop.
  if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
    genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                      merger.lat(li).bits, whileOp);
    return needsUniv;
  }
  // End a for-loop.
  genForInduction(merger, codegen, rewriter, op, loop);
  return false;
}

/// Ends a loop sequence at given level.
static void endLoopSeq(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op,
                       unsigned exp, unsigned at, unsigned idx, unsigned ldx) {
  assert(codegen.curVecLength == 1);
  codegen.loops[idx] = Value();
  // Bring a pending reduction back from SIMD form when sequence ends.
  if (codegen.redVal)
    if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>())
      updateReduc(merger, codegen,
                  genVectorReducEnd(codegen, rewriter, op.getLoc(), vtp));
  // Unmark bookkeeping of invariants and loop index.
  genInvariants(merger, codegen, rewriter, op, exp, ldx, /*atStart=*/false);
  // Finalize access pattern expansion for sparse tensor output.
  genExpansion(merger, codegen, rewriter, op, at, /*atStart=*/false);
}
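// Note (descriptive only): when the reduction was vectorized, the call to
// genVectorReducEnd above collapses the SIMD accumulator back into a single
// scalar (e.g. a horizontal sum for a sum reduction), so that code outside
// the loop sequence always observes the reduction in scalar form.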

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned ldx = topSort[at - 1];
    Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx);
    genTensorStore(merger, codegen, rewriter, op, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  unsigned idx = topSort[at];
  unsigned ldx = at == 0 ? -1u : topSort[at - 1];
  unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx));

  // Start a loop sequence.
  bool needsUniv = startLoopSeq(merger, codegen, rewriter, op, topSort, exp, at,
                                idx, ldx, lts);

  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
  unsigned lsize = merger.set(lts).size();
  for (unsigned i = 0; i < lsize; i++) {
    // Start a loop.
    unsigned li = merger.set(lts)[i];
    Operation *loop =
        startLoop(merger, codegen, rewriter, op, topSort, at, li, needsUniv);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    Value redInput = codegen.redVal;
    Value cntInput = codegen.expCount;
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    for (unsigned j = 0; j < lsize; j++) {
      unsigned lj = merger.set(lts)[j];
      unsigned ej = merger.lat(lj).exp;
      if (li == lj || merger.latGT(li, lj)) {
        // Recurse into body of each branch.
        if (isWhile) {
          scf::IfOp ifOp =
              genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple);
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
          endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput);
        } else {
          genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1);
        }
      }
    }

    // End a loop.
    needsUniv =
        endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv);
  }

  // End a loop sequence.
  endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx);
}

/// Converts the result computed by the sparse kernel into the required form.
static void genResult(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op) {
  OpOperand *lhs = op.getOutputOperand(0);
  Type resType = lhs->get().getType();
  if (getSparseTensorEncoding(resType)) {
    // The sparse tensor rematerializes from the original sparse tensor's
    // underlying sparse storage format.
    rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(),
                                        codegen.sparseOut == lhs);
  } else {
    // To rematerialize a non-annotated tensor, simply load it
    // from the bufferized value.
    Value val = codegen.buffers.back(); // value array
    rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
  }
}
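// End to end, and purely as an illustration (buffer names invented), a kernel
// like x(i) = a(i) * b(i) with "a" annotated as compressed and "b" and "x"
// dense lowers to roughly:
//
//   %lo = memref.load %a_pointers[%c0] : memref<?xindex>
//   %hi = memref.load %a_pointers[%c1] : memref<?xindex>
//   scf.for %p = %lo to %hi step %c1 {
//     %i = memref.load %a_indices[%p] : memref<?xindex>
//     %0 = memref.load %a_values[%p] : memref<?xf64>
//     %1 = memref.load %b_buffer[%i] : memref<?xf64>
//     %2 = arith.mulf %0, %1 : f64
//     memref.store %2, %x_buffer[%i] : memref<?xf64>
//   }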

//===----------------------------------------------------------------------===//
// Sparse compiler rewriting methods.
//===----------------------------------------------------------------------===//

namespace {

/// Sparse rewriting rule for generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    if (!findSparseAnnotations(merger, op))
      return failure();

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(merger, op, topSort,
                               SortMask::kIncludeUndef |
                                   SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
      return failure();

    // Builds the tensor expression for the Linalg operation in SSA form.
    Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op);
    if (!optExp.hasValue())
      return failure();
    unsigned exp = optExp.getValue();

    // Rejects an inadmissible tensor expression.
    OpOperand *sparseOut = nullptr;
    unsigned outerParNest = 0;
    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
                               outerParNest))
      return failure();

    // Recursively generates code.
    merger.setHasSparseOut(sparseOut != nullptr);
    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
    genResult(merger, codegen, rewriter, op);
    return success();
  }

private:
  /// Options to control sparse code generation.
  SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparsificationPatterns(
    RewritePatternSet &patterns, const SparsificationOptions &options) {
  patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
}
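
// For reference, a pass typically applies these patterns roughly as follows
// (a sketch only; the actual pass driver lives elsewhere in the SparseTensor
// transforms and also configures sparsification options and, when requested,
// vector lowering):
//
//   RewritePatternSet patterns(ctx);
//   populateSparsificationPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));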